statement.go 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240
  1. // Copyright 2015 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "bytes"
  7. "database/sql/driver"
  8. "encoding/json"
  9. "errors"
  10. "fmt"
  11. "reflect"
  12. "strings"
  13. "time"
  14. "github.com/go-xorm/builder"
  15. "github.com/go-xorm/core"
  16. )
// Statement saves all the SQL information needed to build and execute one
// statement: the target table, WHERE conditions, selected/omitted columns,
// ordering, limits, joins, raw SQL and update options. Init resets it so it
// can be reused.
type Statement struct {
	// RefTable is the mapped table metadata of the bean; set by setRefValue,
	// setRefBean and Table.
	RefTable *core.Table
	// Engine is the owning engine, used for quoting, the dialect and logging.
	// Init does not reset it.
	Engine *Engine
	// Start and LimitN are the offset and row limit set by Limit.
	Start  int
	LimitN int
	// idParam holds the primary-key value(s) set by ID.
	idParam *core.PK
	// OrderStr is the ORDER BY text accumulated by OrderBy, Asc and Desc.
	OrderStr string
	// JoinStr is the JOIN clause text built by Join; joinArgs collects the
	// args passed to Join.
	JoinStr  string
	joinArgs []interface{}
	// GroupByStr is set by GroupBy. HavingStr is set by Having and already
	// includes the "HAVING " keyword.
	GroupByStr string
	HavingStr  string
	// ColumnStr is the quoted, comma-separated column list rebuilt by Cols.
	ColumnStr string
	// selectStr, when set by Select, replaces the column list.
	selectStr string
	// useAllCols is set by AllCols; buildUpdates then treats every column as
	// required.
	useAllCols bool
	// OmitStr is rebuilt by Omit from the columns of its most recent call.
	OmitStr string
	// AltTableName is set by Table and takes precedence over tableName in
	// TableName().
	AltTableName string
	// tableName is derived from the bean by setRefValue / setRefBean.
	tableName string
	// RawSQL and RawParams hold a raw statement and its parameters set by SQL.
	RawSQL    string
	RawParams []interface{}
	// UseCascade is switched on by Init; it is not read in the code shown here.
	UseCascade bool
	// UseAutoJoin is not touched by Init or by the code shown here.
	UseAutoJoin bool
	// StoreEngine and Charset are passed to the dialect by genCreateTableSQL;
	// Init does not reset them.
	StoreEngine string
	Charset     string
	// UseCache and UseAutoTime are switched on by Init; they are not read in
	// the code shown here.
	UseCache    bool
	UseAutoTime bool
	// noAutoCondition is set by NoAutoCondition; when true, mergeConds does
	// not add conditions derived from the bean.
	noAutoCondition bool
	// IsDistinct is set by Distinct, IsForUpdate by ForUpdate.
	IsDistinct  bool
	IsForUpdate bool
	// TableAlias is set by Alias and preferred over the table name when
	// qualifying columns in joined queries.
	TableAlias string
	// allUseBool is set by UseBool() with no arguments; buildUpdates then
	// includes bool fields even when false.
	allUseBool bool
	// checkVersion is switched on by Init; it is not read in the code shown here.
	checkVersion bool
	// unscoped is set by Unscoped; buildUpdates then keeps "deleted" columns.
	unscoped bool
	// columnMap collects the columns passed to Cols; omitColumnMap collects
	// the columns passed to Omit.
	columnMap     columnMap
	omitColumnMap columnMap
	// mustColumnMap (MustCols) and nullableMap (Nullable) are keyed by
	// lower-cased column name.
	mustColumnMap map[string]bool
	nullableMap   map[string]bool
	// incrColumns, decrColumns and exprColumns hold the pending Incr, Decr and
	// SetExpr updates, keyed by lower-cased column name.
	incrColumns map[string]incrParam
	decrColumns map[string]decrParam
	exprColumns map[string]exprParam
	// cond is the accumulated WHERE condition.
	cond builder.Cond
	// bufferSize is reset to 0 by Init; it is not read in the code shown here.
	bufferSize int
}
  60. // Init reset all the statement's fields
  61. func (statement *Statement) Init() {
  62. statement.RefTable = nil
  63. statement.Start = 0
  64. statement.LimitN = 0
  65. statement.OrderStr = ""
  66. statement.UseCascade = true
  67. statement.JoinStr = ""
  68. statement.joinArgs = make([]interface{}, 0)
  69. statement.GroupByStr = ""
  70. statement.HavingStr = ""
  71. statement.ColumnStr = ""
  72. statement.OmitStr = ""
  73. statement.columnMap = columnMap{}
  74. statement.omitColumnMap = columnMap{}
  75. statement.AltTableName = ""
  76. statement.tableName = ""
  77. statement.idParam = nil
  78. statement.RawSQL = ""
  79. statement.RawParams = make([]interface{}, 0)
  80. statement.UseCache = true
  81. statement.UseAutoTime = true
  82. statement.noAutoCondition = false
  83. statement.IsDistinct = false
  84. statement.IsForUpdate = false
  85. statement.TableAlias = ""
  86. statement.selectStr = ""
  87. statement.allUseBool = false
  88. statement.useAllCols = false
  89. statement.mustColumnMap = make(map[string]bool)
  90. statement.nullableMap = make(map[string]bool)
  91. statement.checkVersion = true
  92. statement.unscoped = false
  93. statement.incrColumns = make(map[string]incrParam)
  94. statement.decrColumns = make(map[string]decrParam)
  95. statement.exprColumns = make(map[string]exprParam)
  96. statement.cond = builder.NewCond()
  97. statement.bufferSize = 0
  98. }
  99. // NoAutoCondition if you do not want convert bean's field as query condition, then use this function
  100. func (statement *Statement) NoAutoCondition(no ...bool) *Statement {
  101. statement.noAutoCondition = true
  102. if len(no) > 0 {
  103. statement.noAutoCondition = no[0]
  104. }
  105. return statement
  106. }
  107. // Alias set the table alias
  108. func (statement *Statement) Alias(alias string) *Statement {
  109. statement.TableAlias = alias
  110. return statement
  111. }
  112. // SQL adds raw sql statement
  113. func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement {
  114. switch query.(type) {
  115. case (*builder.Builder):
  116. var err error
  117. statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL()
  118. if err != nil {
  119. statement.Engine.logger.Error(err)
  120. }
  121. case string:
  122. statement.RawSQL = query.(string)
  123. statement.RawParams = args
  124. default:
  125. statement.Engine.logger.Error("unsupported sql type")
  126. }
  127. return statement
  128. }
  129. // Where add Where statement
  130. func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement {
  131. return statement.And(query, args...)
  132. }
  133. // And add Where & and statement
  134. func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
  135. switch query.(type) {
  136. case string:
  137. cond := builder.Expr(query.(string), args...)
  138. statement.cond = statement.cond.And(cond)
  139. case map[string]interface{}:
  140. cond := builder.Eq(query.(map[string]interface{}))
  141. statement.cond = statement.cond.And(cond)
  142. case builder.Cond:
  143. cond := query.(builder.Cond)
  144. statement.cond = statement.cond.And(cond)
  145. for _, v := range args {
  146. if vv, ok := v.(builder.Cond); ok {
  147. statement.cond = statement.cond.And(vv)
  148. }
  149. }
  150. default:
  151. // TODO: not support condition type
  152. }
  153. return statement
  154. }
  155. // Or add Where & Or statement
  156. func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
  157. switch query.(type) {
  158. case string:
  159. cond := builder.Expr(query.(string), args...)
  160. statement.cond = statement.cond.Or(cond)
  161. case map[string]interface{}:
  162. cond := builder.Eq(query.(map[string]interface{}))
  163. statement.cond = statement.cond.Or(cond)
  164. case builder.Cond:
  165. cond := query.(builder.Cond)
  166. statement.cond = statement.cond.Or(cond)
  167. for _, v := range args {
  168. if vv, ok := v.(builder.Cond); ok {
  169. statement.cond = statement.cond.Or(vv)
  170. }
  171. }
  172. default:
  173. // TODO: not support condition type
  174. }
  175. return statement
  176. }
  177. // In generate "Where column IN (?) " statement
  178. func (statement *Statement) In(column string, args ...interface{}) *Statement {
  179. in := builder.In(statement.Engine.Quote(column), args...)
  180. statement.cond = statement.cond.And(in)
  181. return statement
  182. }
  183. // NotIn generate "Where column NOT IN (?) " statement
  184. func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
  185. notIn := builder.NotIn(statement.Engine.Quote(column), args...)
  186. statement.cond = statement.cond.And(notIn)
  187. return statement
  188. }
  189. func (statement *Statement) setRefValue(v reflect.Value) error {
  190. var err error
  191. statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v))
  192. if err != nil {
  193. return err
  194. }
  195. statement.tableName = statement.Engine.TableName(v, true)
  196. return nil
  197. }
  198. func (statement *Statement) setRefBean(bean interface{}) error {
  199. var err error
  200. statement.RefTable, err = statement.Engine.autoMapType(rValue(bean))
  201. if err != nil {
  202. return err
  203. }
  204. statement.tableName = statement.Engine.TableName(bean, true)
  205. return nil
  206. }
// buildUpdates walks the columns of statement.RefTable and collects, for the
// given bean, the "col = ?" SET fragments and the matching argument values
// for an UPDATE. It returns the fragments and the arguments.
//
// A column is skipped when:
//   - it is a created column;
//   - it is a version column (unless includeVersion);
//   - it is an updated column (unless includeUpdated);
//   - it is auto-increment (unless includeAutoIncr);
//   - it is a deleted column (unless Unscoped);
//   - it is in the Omit list, or outside a non-empty Cols list;
//   - its mustColumnMap entry is false.
//
// Zero values are generally skipped unless the column is "required": AllCols
// is set, MustCols lists it, or the field is a non-nil pointer.
//
// NOTE(review): the includeNil parameter is shadowed by the local
// `includeNil := useAllCols` inside the loop, so the caller's value is never
// read. The update parameter is not used in this body either.
func (statement *Statement) buildUpdates(bean interface{},
	includeVersion, includeUpdated, includeNil,
	includeAutoIncr, update bool) ([]string, []interface{}) {
	engine := statement.Engine
	table := statement.RefTable
	allUseBool := statement.allUseBool
	useAllCols := statement.useAllCols
	mustColumnMap := statement.mustColumnMap
	nullableMap := statement.nullableMap
	columnMap := statement.columnMap
	omitColumnMap := statement.omitColumnMap
	unscoped := statement.unscoped

	var colNames = make([]string, 0)
	var args = make([]interface{}, 0)
	for _, col := range table.Columns() {
		// Column-level filters: flags, Omit and Cols.
		if !includeVersion && col.IsVersion {
			continue
		}
		if col.IsCreated {
			continue
		}
		if !includeUpdated && col.IsUpdated {
			continue
		}
		if !includeAutoIncr && col.IsAutoIncrement {
			continue
		}
		if col.IsDeleted && !unscoped {
			continue
		}
		if omitColumnMap.contain(col.Name) {
			continue
		}
		if len(columnMap) > 0 && !columnMap.contain(col.Name) {
			continue
		}

		fieldValuePtr, err := col.ValueOf(bean)
		if err != nil {
			// Fields that cannot be read are logged and skipped, not fatal.
			engine.logger.Error(err)
			continue
		}

		fieldValue := *fieldValuePtr
		fieldType := reflect.TypeOf(fieldValue.Interface())
		// A nil interface field has no dynamic type; skip it.
		if fieldType == nil {
			continue
		}

		requiredField := useAllCols
		includeNil := useAllCols

		// MustCols: true forces the column in, false skips it entirely.
		if b, ok := getFlagForColumn(mustColumnMap, col); ok {
			if b {
				requiredField = true
			} else {
				continue
			}
		}

		// !evalphobia! set fieldValue as nil when column is nullable and zero-value
		if b, ok := getFlagForColumn(nullableMap, col); ok {
			if b && col.Nullable && isZero(fieldValue.Interface()) {
				var nilValue *int
				fieldValue = reflect.ValueOf(nilValue)
				fieldType = reflect.TypeOf(fieldValue.Interface())
				includeNil = true
			}
		}

		var val interface{}

		// Types implementing core.Conversion supply their own DB value; the
		// pointer receiver is tried first, then the value. On a ToDB error the
		// error is logged and val stays nil, but the column is still appended.
		if fieldValue.CanAddr() {
			if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
				data, err := structConvert.ToDB()
				if err != nil {
					engine.logger.Error(err)
				} else {
					val = data
				}
				goto APPEND
			}
		}

		if structConvert, ok := fieldValue.Interface().(core.Conversion); ok {
			data, err := structConvert.ToDB()
			if err != nil {
				engine.logger.Error(err)
			} else {
				val = data
			}
			goto APPEND
		}

		if fieldType.Kind() == reflect.Ptr {
			if fieldValue.IsNil() {
				// NOTE(review): this fragment is "%v=?" while the normal path
				// below uses "%v = ?"; cosmetic only.
				if includeNil {
					args = append(args, nil)
					colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name)))
				}
				continue
			} else if !fieldValue.IsValid() {
				continue
			} else {
				// dereference ptr type to instance type
				fieldValue = fieldValue.Elem()
				fieldType = reflect.TypeOf(fieldValue.Interface())
				// A set (non-nil) pointer is always written, even if it
				// points at a zero value.
				requiredField = true
			}
		}

		switch fieldType.Kind() {
		case reflect.Bool:
			if allUseBool || requiredField {
				val = fieldValue.Interface()
			} else {
				// if a bool in a struct, it will not be as a condition because it default is false,
				// please use Where() instead
				continue
			}
		case reflect.String:
			if !requiredField && fieldValue.String() == "" {
				continue
			}
			// for MyString, should convert to string or panic
			if fieldType.String() != reflect.String.String() {
				val = fieldValue.String()
			} else {
				val = fieldValue.Interface()
			}
		case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
			if !requiredField && fieldValue.Int() == 0 {
				continue
			}
			val = fieldValue.Interface()
		case reflect.Float32, reflect.Float64:
			if !requiredField && fieldValue.Float() == 0.0 {
				continue
			}
			val = fieldValue.Interface()
		case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
			if !requiredField && fieldValue.Uint() == 0 {
				continue
			}
			// Unsigned values are sent as a *int64.
			// NOTE(review): presumably relies on the SQL driver dereferencing
			// pointer arguments; values above MaxInt64 wrap negative — confirm.
			t := int64(fieldValue.Uint())
			val = reflect.ValueOf(&t).Interface()
		case reflect.Struct:
			if fieldType.ConvertibleTo(core.TimeType) {
				// Time-like structs are formatted per column by the engine.
				t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
				if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
					continue
				}
				val = engine.formatColTime(col, t)
			} else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok {
				// driver.Valuer error is deliberately ignored here.
				val, _ = nulType.Value()
			} else {
				if !col.SQLType.IsJson() {
					// A struct mapped to another table is stored by its single
					// primary key.
					engine.autoMapType(fieldValue)
					if table, ok := engine.Tables[fieldValue.Type()]; ok {
						if len(table.PrimaryKeys) == 1 {
							pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
							// fix non-int pk issues
							// NOTE(review): when requiredField is true this
							// always skips the column; the condition looks
							// inverted — confirm intent.
							if pkField.IsValid() && (!requiredField && !isZero(pkField.Interface())) {
								val = pkField.Interface()
							} else {
								continue
							}
						} else {
							//TODO: how to handler?
							panic("not supported")
						}
					} else {
						val = fieldValue.Interface()
					}
				} else {
					// Blank struct could not be as update data
					if requiredField || !isStructZero(fieldValue) {
						bytes, err := json.Marshal(fieldValue.Interface())
						if err != nil {
							panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface()))
						}
						// NOTE(review): if the column is neither text nor blob,
						// val stays nil and NULL is written.
						if col.SQLType.IsText() {
							val = string(bytes)
						} else if col.SQLType.IsBlob() {
							val = bytes
						}
					} else {
						continue
					}
				}
			}
		case reflect.Array, reflect.Slice, reflect.Map:
			if !requiredField {
				// NOTE(review): this compares two reflect.Value structs, not the
				// underlying values, so it is unlikely ever to be true.
				if fieldValue == reflect.Zero(fieldType) {
					continue
				}
				if fieldType.Kind() == reflect.Array {
					if isArrayValueZero(fieldValue) {
						continue
					}
				} else if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
					continue
				}
			}

			// Collections are stored as JSON text, raw bytes, or JSON bytes;
			// other column types are skipped.
			if col.SQLType.IsText() {
				bytes, err := json.Marshal(fieldValue.Interface())
				if err != nil {
					engine.logger.Error(err)
					continue
				}
				val = string(bytes)
			} else if col.SQLType.IsBlob() {
				var bytes []byte
				var err error
				if fieldType.Kind() == reflect.Slice &&
					fieldType.Elem().Kind() == reflect.Uint8 {
					if fieldValue.Len() > 0 {
						val = fieldValue.Bytes()
					} else {
						continue
					}
				} else if fieldType.Kind() == reflect.Array &&
					fieldType.Elem().Kind() == reflect.Uint8 {
					// NOTE(review): Slice(0, 0) yields an empty slice, so the
					// array's contents are never written. reflect's Slice also
					// requires an addressable array — confirm and fix.
					val = fieldValue.Slice(0, 0).Interface()
				} else {
					bytes, err = json.Marshal(fieldValue.Interface())
					if err != nil {
						engine.logger.Error(err)
						continue
					}
					val = bytes
				}
			} else {
				continue
			}
		default:
			val = fieldValue.Interface()
		}

	APPEND:
		args = append(args, val)
		// NOTE(review): for "ql" primary keys the arg is appended but the
		// "col = ?" fragment is not, so args and placeholders can get out of
		// step — confirm this is intended.
		if col.IsPrimaryKey && engine.dialect.DBType() == "ql" {
			continue
		}
		colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
	}

	return colNames, args
}
  445. func (statement *Statement) needTableName() bool {
  446. return len(statement.JoinStr) > 0
  447. }
  448. func (statement *Statement) colName(col *core.Column, tableName string) string {
  449. if statement.needTableName() {
  450. var nm = tableName
  451. if len(statement.TableAlias) > 0 {
  452. nm = statement.TableAlias
  453. }
  454. return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name)
  455. }
  456. return statement.Engine.Quote(col.Name)
  457. }
  458. // TableName return current tableName
  459. func (statement *Statement) TableName() string {
  460. if statement.AltTableName != "" {
  461. return statement.AltTableName
  462. }
  463. return statement.tableName
  464. }
  465. // ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
  466. func (statement *Statement) ID(id interface{}) *Statement {
  467. idValue := reflect.ValueOf(id)
  468. idType := reflect.TypeOf(idValue.Interface())
  469. switch idType {
  470. case ptrPkType:
  471. if pkPtr, ok := (id).(*core.PK); ok {
  472. statement.idParam = pkPtr
  473. return statement
  474. }
  475. case pkType:
  476. if pk, ok := (id).(core.PK); ok {
  477. statement.idParam = &pk
  478. return statement
  479. }
  480. }
  481. switch idType.Kind() {
  482. case reflect.String:
  483. statement.idParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
  484. return statement
  485. }
  486. statement.idParam = &core.PK{id}
  487. return statement
  488. }
  489. // Incr Generate "Update ... Set column = column + arg" statement
  490. func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
  491. k := strings.ToLower(column)
  492. if len(arg) > 0 {
  493. statement.incrColumns[k] = incrParam{column, arg[0]}
  494. } else {
  495. statement.incrColumns[k] = incrParam{column, 1}
  496. }
  497. return statement
  498. }
  499. // Decr Generate "Update ... Set column = column - arg" statement
  500. func (statement *Statement) Decr(column string, arg ...interface{}) *Statement {
  501. k := strings.ToLower(column)
  502. if len(arg) > 0 {
  503. statement.decrColumns[k] = decrParam{column, arg[0]}
  504. } else {
  505. statement.decrColumns[k] = decrParam{column, 1}
  506. }
  507. return statement
  508. }
  509. // SetExpr Generate "Update ... Set column = {expression}" statement
  510. func (statement *Statement) SetExpr(column string, expression string) *Statement {
  511. k := strings.ToLower(column)
  512. statement.exprColumns[k] = exprParam{column, expression}
  513. return statement
  514. }
  515. // Generate "Update ... Set column = column + arg" statement
  516. func (statement *Statement) getInc() map[string]incrParam {
  517. return statement.incrColumns
  518. }
  519. // Generate "Update ... Set column = column - arg" statement
  520. func (statement *Statement) getDec() map[string]decrParam {
  521. return statement.decrColumns
  522. }
  523. // Generate "Update ... Set column = {expression}" statement
  524. func (statement *Statement) getExpr() map[string]exprParam {
  525. return statement.exprColumns
  526. }
  527. func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
  528. newColumns := make([]string, 0)
  529. for _, col := range columns {
  530. col = strings.Replace(col, "`", "", -1)
  531. col = strings.Replace(col, statement.Engine.QuoteStr(), "", -1)
  532. ccols := strings.Split(col, ",")
  533. for _, c := range ccols {
  534. fields := strings.Split(strings.TrimSpace(c), ".")
  535. if len(fields) == 1 {
  536. newColumns = append(newColumns, statement.Engine.quote(fields[0]))
  537. } else if len(fields) == 2 {
  538. newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
  539. statement.Engine.quote(fields[1]))
  540. } else {
  541. panic(errors.New("unwanted colnames"))
  542. }
  543. }
  544. }
  545. return newColumns
  546. }
  547. func (statement *Statement) colmap2NewColsWithQuote() []string {
  548. newColumns := make([]string, len(statement.columnMap), len(statement.columnMap))
  549. copy(newColumns, statement.columnMap)
  550. for i := 0; i < len(statement.columnMap); i++ {
  551. newColumns[i] = statement.Engine.Quote(newColumns[i])
  552. }
  553. return newColumns
  554. }
  555. // Distinct generates "DISTINCT col1, col2 " statement
  556. func (statement *Statement) Distinct(columns ...string) *Statement {
  557. statement.IsDistinct = true
  558. statement.Cols(columns...)
  559. return statement
  560. }
  561. // ForUpdate generates "SELECT ... FOR UPDATE" statement
  562. func (statement *Statement) ForUpdate() *Statement {
  563. statement.IsForUpdate = true
  564. return statement
  565. }
  566. // Select replace select
  567. func (statement *Statement) Select(str string) *Statement {
  568. statement.selectStr = str
  569. return statement
  570. }
  571. // Cols generate "col1, col2" statement
  572. func (statement *Statement) Cols(columns ...string) *Statement {
  573. cols := col2NewCols(columns...)
  574. for _, nc := range cols {
  575. statement.columnMap.add(nc)
  576. }
  577. newColumns := statement.colmap2NewColsWithQuote()
  578. statement.ColumnStr = strings.Join(newColumns, ", ")
  579. statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1)
  580. return statement
  581. }
  582. // AllCols update use only: update all columns
  583. func (statement *Statement) AllCols() *Statement {
  584. statement.useAllCols = true
  585. return statement
  586. }
  587. // MustCols update use only: must update columns
  588. func (statement *Statement) MustCols(columns ...string) *Statement {
  589. newColumns := col2NewCols(columns...)
  590. for _, nc := range newColumns {
  591. statement.mustColumnMap[strings.ToLower(nc)] = true
  592. }
  593. return statement
  594. }
  595. // UseBool indicates that use bool fields as update contents and query contiditions
  596. func (statement *Statement) UseBool(columns ...string) *Statement {
  597. if len(columns) > 0 {
  598. statement.MustCols(columns...)
  599. } else {
  600. statement.allUseBool = true
  601. }
  602. return statement
  603. }
  604. // Omit do not use the columns
  605. func (statement *Statement) Omit(columns ...string) {
  606. newColumns := col2NewCols(columns...)
  607. for _, nc := range newColumns {
  608. statement.omitColumnMap = append(statement.omitColumnMap, nc)
  609. }
  610. statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
  611. }
  612. // Nullable Update use only: update columns to null when value is nullable and zero-value
  613. func (statement *Statement) Nullable(columns ...string) {
  614. newColumns := col2NewCols(columns...)
  615. for _, nc := range newColumns {
  616. statement.nullableMap[strings.ToLower(nc)] = true
  617. }
  618. }
  619. // Top generate LIMIT limit statement
  620. func (statement *Statement) Top(limit int) *Statement {
  621. statement.Limit(limit)
  622. return statement
  623. }
  624. // Limit generate LIMIT start, limit statement
  625. func (statement *Statement) Limit(limit int, start ...int) *Statement {
  626. statement.LimitN = limit
  627. if len(start) > 0 {
  628. statement.Start = start[0]
  629. }
  630. return statement
  631. }
  632. // OrderBy generate "Order By order" statement
  633. func (statement *Statement) OrderBy(order string) *Statement {
  634. if len(statement.OrderStr) > 0 {
  635. statement.OrderStr += ", "
  636. }
  637. statement.OrderStr += order
  638. return statement
  639. }
  640. // Desc generate `ORDER BY xx DESC`
  641. func (statement *Statement) Desc(colNames ...string) *Statement {
  642. var buf bytes.Buffer
  643. fmt.Fprintf(&buf, statement.OrderStr)
  644. if len(statement.OrderStr) > 0 {
  645. fmt.Fprint(&buf, ", ")
  646. }
  647. newColNames := statement.col2NewColsWithQuote(colNames...)
  648. fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, "))
  649. statement.OrderStr = buf.String()
  650. return statement
  651. }
  652. // Asc provide asc order by query condition, the input parameters are columns.
  653. func (statement *Statement) Asc(colNames ...string) *Statement {
  654. var buf bytes.Buffer
  655. fmt.Fprintf(&buf, statement.OrderStr)
  656. if len(statement.OrderStr) > 0 {
  657. fmt.Fprint(&buf, ", ")
  658. }
  659. newColNames := statement.col2NewColsWithQuote(colNames...)
  660. fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, "))
  661. statement.OrderStr = buf.String()
  662. return statement
  663. }
  664. // Table tempororily set table name, the parameter could be a string or a pointer of struct
  665. func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
  666. v := rValue(tableNameOrBean)
  667. t := v.Type()
  668. if t.Kind() == reflect.Struct {
  669. var err error
  670. statement.RefTable, err = statement.Engine.autoMapType(v)
  671. if err != nil {
  672. statement.Engine.logger.Error(err)
  673. return statement
  674. }
  675. }
  676. statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true)
  677. return statement
  678. }
  679. // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
  680. func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
  681. var buf bytes.Buffer
  682. if len(statement.JoinStr) > 0 {
  683. fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
  684. } else {
  685. fmt.Fprintf(&buf, "%v JOIN ", joinOP)
  686. }
  687. tbName := statement.Engine.TableName(tablename, true)
  688. fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
  689. statement.JoinStr = buf.String()
  690. statement.joinArgs = append(statement.joinArgs, args...)
  691. return statement
  692. }
  693. // GroupBy generate "Group By keys" statement
  694. func (statement *Statement) GroupBy(keys string) *Statement {
  695. statement.GroupByStr = keys
  696. return statement
  697. }
  698. // Having generate "Having conditions" statement
  699. func (statement *Statement) Having(conditions string) *Statement {
  700. statement.HavingStr = fmt.Sprintf("HAVING %v", conditions)
  701. return statement
  702. }
  703. // Unscoped always disable struct tag "deleted"
  704. func (statement *Statement) Unscoped() *Statement {
  705. statement.unscoped = true
  706. return statement
  707. }
  708. func (statement *Statement) genColumnStr() string {
  709. var buf bytes.Buffer
  710. if statement.RefTable == nil {
  711. return ""
  712. }
  713. columns := statement.RefTable.Columns()
  714. for _, col := range columns {
  715. if statement.omitColumnMap.contain(col.Name) {
  716. continue
  717. }
  718. if len(statement.columnMap) > 0 && !statement.columnMap.contain(col.Name) {
  719. continue
  720. }
  721. if col.MapType == core.ONLYTODB {
  722. continue
  723. }
  724. if buf.Len() != 0 {
  725. buf.WriteString(", ")
  726. }
  727. if statement.JoinStr != "" {
  728. if statement.TableAlias != "" {
  729. buf.WriteString(statement.TableAlias)
  730. } else {
  731. buf.WriteString(statement.TableName())
  732. }
  733. buf.WriteString(".")
  734. }
  735. statement.Engine.QuoteTo(&buf, col.Name)
  736. }
  737. return buf.String()
  738. }
  739. func (statement *Statement) genCreateTableSQL() string {
  740. return statement.Engine.dialect.CreateTableSql(statement.RefTable, statement.TableName(),
  741. statement.StoreEngine, statement.Charset)
  742. }
  743. func (statement *Statement) genIndexSQL() []string {
  744. var sqls []string
  745. tbName := statement.TableName()
  746. for _, index := range statement.RefTable.Indexes {
  747. if index.Type == core.IndexType {
  748. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  749. /*idxTBName := strings.Replace(tbName, ".", "_", -1)
  750. idxTBName = strings.Replace(idxTBName, `"`, "", -1)
  751. sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)),
  752. quote(tbName), quote(strings.Join(index.Cols, quote(","))))*/
  753. sqls = append(sqls, sql)
  754. }
  755. }
  756. return sqls
  757. }
  758. func uniqueName(tableName, uqeName string) string {
  759. return fmt.Sprintf("UQE_%v_%v", tableName, uqeName)
  760. }
  761. func (statement *Statement) genUniqueSQL() []string {
  762. var sqls []string
  763. tbName := statement.TableName()
  764. for _, index := range statement.RefTable.Indexes {
  765. if index.Type == core.UniqueType {
  766. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  767. sqls = append(sqls, sql)
  768. }
  769. }
  770. return sqls
  771. }
  772. func (statement *Statement) genDelIndexSQL() []string {
  773. var sqls []string
  774. tbName := statement.TableName()
  775. idxPrefixName := strings.Replace(tbName, `"`, "", -1)
  776. idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1)
  777. for idxName, index := range statement.RefTable.Indexes {
  778. var rIdxName string
  779. if index.Type == core.UniqueType {
  780. rIdxName = uniqueName(idxPrefixName, idxName)
  781. } else if index.Type == core.IndexType {
  782. rIdxName = indexName(idxPrefixName, idxName)
  783. }
  784. sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true)))
  785. if statement.Engine.dialect.IndexOnTable() {
  786. sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
  787. }
  788. sqls = append(sqls, sql)
  789. }
  790. return sqls
  791. }
  792. func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
  793. quote := statement.Engine.Quote
  794. sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
  795. col.String(statement.Engine.dialect))
  796. if statement.Engine.dialect.DBType() == core.MYSQL && len(col.Comment) > 0 {
  797. sql += " COMMENT '" + col.Comment + "'"
  798. }
  799. sql += ";"
  800. return sql, []interface{}{}
  801. }
  802. func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
  803. return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
  804. statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
  805. }
  806. func (statement *Statement) mergeConds(bean interface{}) error {
  807. if !statement.noAutoCondition {
  808. var addedTableName = (len(statement.JoinStr) > 0)
  809. autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
  810. if err != nil {
  811. return err
  812. }
  813. statement.cond = statement.cond.And(autoCond)
  814. }
  815. if err := statement.processIDParam(); err != nil {
  816. return err
  817. }
  818. return nil
  819. }
  820. func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
  821. if err := statement.mergeConds(bean); err != nil {
  822. return "", nil, err
  823. }
  824. return builder.ToSQL(statement.cond)
  825. }
  826. func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
  827. v := rValue(bean)
  828. isStruct := v.Kind() == reflect.Struct
  829. if isStruct {
  830. statement.setRefBean(bean)
  831. }
  832. var columnStr = statement.ColumnStr
  833. if len(statement.selectStr) > 0 {
  834. columnStr = statement.selectStr
  835. } else {
  836. // TODO: always generate column names, not use * even if join
  837. if len(statement.JoinStr) == 0 {
  838. if len(columnStr) == 0 {
  839. if len(statement.GroupByStr) > 0 {
  840. columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
  841. } else {
  842. columnStr = statement.genColumnStr()
  843. }
  844. }
  845. } else {
  846. if len(columnStr) == 0 {
  847. if len(statement.GroupByStr) > 0 {
  848. columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
  849. }
  850. }
  851. }
  852. }
  853. if len(columnStr) == 0 {
  854. columnStr = "*"
  855. }
  856. if isStruct {
  857. if err := statement.mergeConds(bean); err != nil {
  858. return "", nil, err
  859. }
  860. } else {
  861. if err := statement.processIDParam(); err != nil {
  862. return "", nil, err
  863. }
  864. }
  865. condSQL, condArgs, err := builder.ToSQL(statement.cond)
  866. if err != nil {
  867. return "", nil, err
  868. }
  869. sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true, true)
  870. if err != nil {
  871. return "", nil, err
  872. }
  873. return sqlStr, append(statement.joinArgs, condArgs...), nil
  874. }
  875. func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) {
  876. var condSQL string
  877. var condArgs []interface{}
  878. var err error
  879. if len(beans) > 0 {
  880. statement.setRefBean(beans[0])
  881. condSQL, condArgs, err = statement.genConds(beans[0])
  882. } else {
  883. condSQL, condArgs, err = builder.ToSQL(statement.cond)
  884. }
  885. if err != nil {
  886. return "", nil, err
  887. }
  888. var selectSQL = statement.selectStr
  889. if len(selectSQL) <= 0 {
  890. if statement.IsDistinct {
  891. selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr)
  892. } else {
  893. selectSQL = "count(*)"
  894. }
  895. }
  896. sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false, false)
  897. if err != nil {
  898. return "", nil, err
  899. }
  900. return sqlStr, append(statement.joinArgs, condArgs...), nil
  901. }
  902. func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
  903. statement.setRefBean(bean)
  904. var sumStrs = make([]string, 0, len(columns))
  905. for _, colName := range columns {
  906. if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
  907. colName = statement.Engine.Quote(colName)
  908. }
  909. sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
  910. }
  911. sumSelect := strings.Join(sumStrs, ", ")
  912. condSQL, condArgs, err := statement.genConds(bean)
  913. if err != nil {
  914. return "", nil, err
  915. }
  916. sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true, true)
  917. if err != nil {
  918. return "", nil, err
  919. }
  920. return sqlStr, append(statement.joinArgs, condArgs...), nil
  921. }
  922. func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (a string, err error) {
  923. var distinct string
  924. if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
  925. distinct = "DISTINCT "
  926. }
  927. var dialect = statement.Engine.Dialect()
  928. var quote = statement.Engine.Quote
  929. var top string
  930. var mssqlCondi string
  931. var buf bytes.Buffer
  932. if len(condSQL) > 0 {
  933. fmt.Fprintf(&buf, " WHERE %v", condSQL)
  934. }
  935. var whereStr = buf.String()
  936. var fromStr = " FROM "
  937. if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") {
  938. fromStr += statement.TableName()
  939. } else {
  940. fromStr += quote(statement.TableName())
  941. }
  942. if statement.TableAlias != "" {
  943. if dialect.DBType() == core.ORACLE {
  944. fromStr += " " + quote(statement.TableAlias)
  945. } else {
  946. fromStr += " AS " + quote(statement.TableAlias)
  947. }
  948. }
  949. if statement.JoinStr != "" {
  950. fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
  951. }
  952. if dialect.DBType() == core.MSSQL {
  953. if statement.LimitN > 0 {
  954. top = fmt.Sprintf(" TOP %d ", statement.LimitN)
  955. }
  956. if statement.Start > 0 {
  957. var column string
  958. if len(statement.RefTable.PKColumns()) == 0 {
  959. for _, index := range statement.RefTable.Indexes {
  960. if len(index.Cols) == 1 {
  961. column = index.Cols[0]
  962. break
  963. }
  964. }
  965. if len(column) == 0 {
  966. column = statement.RefTable.ColumnsSeq()[0]
  967. }
  968. } else {
  969. column = statement.RefTable.PKColumns()[0].Name
  970. }
  971. if statement.needTableName() {
  972. if len(statement.TableAlias) > 0 {
  973. column = statement.TableAlias + "." + column
  974. } else {
  975. column = statement.TableName() + "." + column
  976. }
  977. }
  978. var orderStr string
  979. if needOrderBy && len(statement.OrderStr) > 0 {
  980. orderStr = " ORDER BY " + statement.OrderStr
  981. }
  982. var groupStr string
  983. if len(statement.GroupByStr) > 0 {
  984. groupStr = " GROUP BY " + statement.GroupByStr
  985. }
  986. mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
  987. column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
  988. }
  989. }
  990. // !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern
  991. a = fmt.Sprintf("SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
  992. if len(mssqlCondi) > 0 {
  993. if len(whereStr) > 0 {
  994. a += " AND " + mssqlCondi
  995. } else {
  996. a += " WHERE " + mssqlCondi
  997. }
  998. }
  999. if statement.GroupByStr != "" {
  1000. a = fmt.Sprintf("%v GROUP BY %v", a, statement.GroupByStr)
  1001. }
  1002. if statement.HavingStr != "" {
  1003. a = fmt.Sprintf("%v %v", a, statement.HavingStr)
  1004. }
  1005. if needOrderBy && statement.OrderStr != "" {
  1006. a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
  1007. }
  1008. if needLimit {
  1009. if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
  1010. if statement.Start > 0 {
  1011. a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
  1012. } else if statement.LimitN > 0 {
  1013. a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
  1014. }
  1015. } else if dialect.DBType() == core.ORACLE {
  1016. if statement.Start != 0 || statement.LimitN != 0 {
  1017. a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
  1018. }
  1019. }
  1020. }
  1021. if statement.IsForUpdate {
  1022. a = dialect.ForUpdateSql(a)
  1023. }
  1024. return
  1025. }
  1026. func (statement *Statement) processIDParam() error {
  1027. if statement.idParam == nil || statement.RefTable == nil {
  1028. return nil
  1029. }
  1030. if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) {
  1031. return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
  1032. len(statement.RefTable.PrimaryKeys),
  1033. len(*statement.idParam),
  1034. )
  1035. }
  1036. for i, col := range statement.RefTable.PKColumns() {
  1037. var colName = statement.colName(col, statement.TableName())
  1038. statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
  1039. }
  1040. return nil
  1041. }
  1042. func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {
  1043. var colnames = make([]string, len(cols))
  1044. for i, col := range cols {
  1045. if includeTableName {
  1046. colnames[i] = statement.Engine.Quote(statement.TableName()) +
  1047. "." + statement.Engine.Quote(col.Name)
  1048. } else {
  1049. colnames[i] = statement.Engine.Quote(col.Name)
  1050. }
  1051. }
  1052. return strings.Join(colnames, ", ")
  1053. }
  1054. func (statement *Statement) convertIDSQL(sqlStr string) string {
  1055. if statement.RefTable != nil {
  1056. cols := statement.RefTable.PKColumns()
  1057. if len(cols) == 0 {
  1058. return ""
  1059. }
  1060. colstrs := statement.joinColumns(cols, false)
  1061. sqls := splitNNoCase(sqlStr, " from ", 2)
  1062. if len(sqls) != 2 {
  1063. return ""
  1064. }
  1065. var top string
  1066. if statement.LimitN > 0 && statement.Engine.dialect.DBType() == core.MSSQL {
  1067. top = fmt.Sprintf("TOP %d ", statement.LimitN)
  1068. }
  1069. newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
  1070. return newsql
  1071. }
  1072. return ""
  1073. }
  1074. func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
  1075. if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 {
  1076. return "", ""
  1077. }
  1078. colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true)
  1079. sqls := splitNNoCase(sqlStr, "where", 2)
  1080. if len(sqls) != 2 {
  1081. if len(sqls) == 1 {
  1082. return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
  1083. colstrs, statement.Engine.Quote(statement.TableName()))
  1084. }
  1085. return "", ""
  1086. }
  1087. var whereStr = sqls[1]
  1088. //TODO: for postgres only, if any other database?
  1089. var paraStr string
  1090. if statement.Engine.dialect.DBType() == core.POSTGRES {
  1091. paraStr = "$"
  1092. } else if statement.Engine.dialect.DBType() == core.MSSQL {
  1093. paraStr = ":"
  1094. }
  1095. if paraStr != "" {
  1096. if strings.Contains(sqls[1], paraStr) {
  1097. dollers := strings.Split(sqls[1], paraStr)
  1098. whereStr = dollers[0]
  1099. for i, c := range dollers[1:] {
  1100. ccs := strings.SplitN(c, " ", 2)
  1101. whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1])
  1102. }
  1103. }
  1104. }
  1105. return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
  1106. colstrs, statement.Engine.Quote(statement.TableName()),
  1107. whereStr)
  1108. }