callback.go 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. // Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
  2. //
  3. // Use of this source code is governed by an MIT-style
  4. // license that can be found in the LICENSE file.
  5. package sqlite3
  6. // You can't export a Go function to C and have definitions in the C
  7. // preamble in the same file, so we have to have callbackTrampoline in
  8. // its own file. Because we need a separate file anyway, the support
  9. // code for SQLite custom functions is in here.
  10. /*
  11. #ifndef USE_LIBSQLITE3
  12. #include <sqlite3-binding.h>
  13. #else
  14. #include <sqlite3.h>
  15. #endif
  16. #include <stdlib.h>
  17. void _sqlite3_result_text(sqlite3_context* ctx, const char* s);
  18. void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l);
  19. */
  20. import "C"
  21. import (
  22. "errors"
  23. "fmt"
  24. "math"
  25. "reflect"
  26. "sync"
  27. "unsafe"
  28. )
  29. //export callbackTrampoline
  30. func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
  31. args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
  32. fi := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*functionInfo)
  33. fi.Call(ctx, args)
  34. }
  35. //export stepTrampoline
  36. func stepTrampoline(ctx *C.sqlite3_context, argc C.int, argv **C.sqlite3_value) {
  37. args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:int(argc):int(argc)]
  38. ai := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*aggInfo)
  39. ai.Step(ctx, args)
  40. }
  41. //export doneTrampoline
  42. func doneTrampoline(ctx *C.sqlite3_context) {
  43. handle := uintptr(C.sqlite3_user_data(ctx))
  44. ai := lookupHandle(handle).(*aggInfo)
  45. ai.Done(ctx)
  46. }
// Use handles to avoid passing Go pointers to C.
//
// handleVal is one entry of the handle table: a registered value together
// with the connection it was registered for.
type handleVal struct {
	db  *SQLiteConn // connection given to newHandle; deleteHandles matches on it
	val interface{} // value handed back by lookupHandle
}
  52. var handleLock sync.Mutex
  53. var handleVals = make(map[uintptr]handleVal)
  54. var handleIndex uintptr = 100
  55. func newHandle(db *SQLiteConn, v interface{}) uintptr {
  56. handleLock.Lock()
  57. defer handleLock.Unlock()
  58. i := handleIndex
  59. handleIndex++
  60. handleVals[i] = handleVal{db, v}
  61. return i
  62. }
  63. func lookupHandle(handle uintptr) interface{} {
  64. handleLock.Lock()
  65. defer handleLock.Unlock()
  66. r, ok := handleVals[handle]
  67. if !ok {
  68. if handle >= 100 && handle < handleIndex {
  69. panic("deleted handle")
  70. } else {
  71. panic("invalid handle")
  72. }
  73. }
  74. return r.val
  75. }
  76. func deleteHandles(db *SQLiteConn) {
  77. handleLock.Lock()
  78. defer handleLock.Unlock()
  79. for handle, val := range handleVals {
  80. if val.db == db {
  81. delete(handleVals, handle)
  82. }
  83. }
  84. }
// callbackArgRaw is a Go-named type whose underlying type is
// C.sqlite3_value.
//
// This is only here so that tests can refer to it.
type callbackArgRaw C.sqlite3_value
// callbackArgConverter turns a single SQLite value into a reflect.Value, or
// reports why the value cannot be converted.
type callbackArgConverter func(*C.sqlite3_value) (reflect.Value, error)
  88. type callbackArgCast struct {
  89. f callbackArgConverter
  90. typ reflect.Type
  91. }
  92. func (c callbackArgCast) Run(v *C.sqlite3_value) (reflect.Value, error) {
  93. val, err := c.f(v)
  94. if err != nil {
  95. return reflect.Value{}, err
  96. }
  97. if !val.Type().ConvertibleTo(c.typ) {
  98. return reflect.Value{}, fmt.Errorf("cannot convert %s to %s", val.Type(), c.typ)
  99. }
  100. return val.Convert(c.typ), nil
  101. }
  102. func callbackArgInt64(v *C.sqlite3_value) (reflect.Value, error) {
  103. if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
  104. return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
  105. }
  106. return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil
  107. }
  108. func callbackArgBool(v *C.sqlite3_value) (reflect.Value, error) {
  109. if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
  110. return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
  111. }
  112. i := int64(C.sqlite3_value_int64(v))
  113. val := false
  114. if i != 0 {
  115. val = true
  116. }
  117. return reflect.ValueOf(val), nil
  118. }
  119. func callbackArgFloat64(v *C.sqlite3_value) (reflect.Value, error) {
  120. if C.sqlite3_value_type(v) != C.SQLITE_FLOAT {
  121. return reflect.Value{}, fmt.Errorf("argument must be a FLOAT")
  122. }
  123. return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil
  124. }
  125. func callbackArgBytes(v *C.sqlite3_value) (reflect.Value, error) {
  126. switch C.sqlite3_value_type(v) {
  127. case C.SQLITE_BLOB:
  128. l := C.sqlite3_value_bytes(v)
  129. p := C.sqlite3_value_blob(v)
  130. return reflect.ValueOf(C.GoBytes(p, l)), nil
  131. case C.SQLITE_TEXT:
  132. l := C.sqlite3_value_bytes(v)
  133. c := unsafe.Pointer(C.sqlite3_value_text(v))
  134. return reflect.ValueOf(C.GoBytes(c, l)), nil
  135. default:
  136. return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
  137. }
  138. }
  139. func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) {
  140. switch C.sqlite3_value_type(v) {
  141. case C.SQLITE_BLOB:
  142. l := C.sqlite3_value_bytes(v)
  143. p := (*C.char)(C.sqlite3_value_blob(v))
  144. return reflect.ValueOf(C.GoStringN(p, l)), nil
  145. case C.SQLITE_TEXT:
  146. c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v)))
  147. return reflect.ValueOf(C.GoString(c)), nil
  148. default:
  149. return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
  150. }
  151. }
  152. func callbackArgGeneric(v *C.sqlite3_value) (reflect.Value, error) {
  153. switch C.sqlite3_value_type(v) {
  154. case C.SQLITE_INTEGER:
  155. return callbackArgInt64(v)
  156. case C.SQLITE_FLOAT:
  157. return callbackArgFloat64(v)
  158. case C.SQLITE_TEXT:
  159. return callbackArgString(v)
  160. case C.SQLITE_BLOB:
  161. return callbackArgBytes(v)
  162. case C.SQLITE_NULL:
  163. // Interpret NULL as a nil byte slice.
  164. var ret []byte
  165. return reflect.ValueOf(ret), nil
  166. default:
  167. panic("unreachable")
  168. }
  169. }
  170. func callbackArg(typ reflect.Type) (callbackArgConverter, error) {
  171. switch typ.Kind() {
  172. case reflect.Interface:
  173. if typ.NumMethod() != 0 {
  174. return nil, errors.New("the only supported interface type is interface{}")
  175. }
  176. return callbackArgGeneric, nil
  177. case reflect.Slice:
  178. if typ.Elem().Kind() != reflect.Uint8 {
  179. return nil, errors.New("the only supported slice type is []byte")
  180. }
  181. return callbackArgBytes, nil
  182. case reflect.String:
  183. return callbackArgString, nil
  184. case reflect.Bool:
  185. return callbackArgBool, nil
  186. case reflect.Int64:
  187. return callbackArgInt64, nil
  188. case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
  189. c := callbackArgCast{callbackArgInt64, typ}
  190. return c.Run, nil
  191. case reflect.Float64:
  192. return callbackArgFloat64, nil
  193. case reflect.Float32:
  194. c := callbackArgCast{callbackArgFloat64, typ}
  195. return c.Run, nil
  196. default:
  197. return nil, fmt.Errorf("don't know how to convert to %s", typ)
  198. }
  199. }
  200. func callbackConvertArgs(argv []*C.sqlite3_value, converters []callbackArgConverter, variadic callbackArgConverter) ([]reflect.Value, error) {
  201. var args []reflect.Value
  202. if len(argv) < len(converters) {
  203. return nil, fmt.Errorf("function requires at least %d arguments", len(converters))
  204. }
  205. for i, arg := range argv[:len(converters)] {
  206. v, err := converters[i](arg)
  207. if err != nil {
  208. return nil, err
  209. }
  210. args = append(args, v)
  211. }
  212. if variadic != nil {
  213. for _, arg := range argv[len(converters):] {
  214. v, err := variadic(arg)
  215. if err != nil {
  216. return nil, err
  217. }
  218. args = append(args, v)
  219. }
  220. }
  221. return args, nil
  222. }
// callbackRetConverter receives a SQLite context and a Go value and returns
// an error when the value cannot be handled. The converters in this file
// set the value as the result on the context.
type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error
  224. func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error {
  225. switch v.Type().Kind() {
  226. case reflect.Int64:
  227. case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
  228. v = v.Convert(reflect.TypeOf(int64(0)))
  229. case reflect.Bool:
  230. b := v.Interface().(bool)
  231. if b {
  232. v = reflect.ValueOf(int64(1))
  233. } else {
  234. v = reflect.ValueOf(int64(0))
  235. }
  236. default:
  237. return fmt.Errorf("cannot convert %s to INTEGER", v.Type())
  238. }
  239. C.sqlite3_result_int64(ctx, C.sqlite3_int64(v.Interface().(int64)))
  240. return nil
  241. }
  242. func callbackRetFloat(ctx *C.sqlite3_context, v reflect.Value) error {
  243. switch v.Type().Kind() {
  244. case reflect.Float64:
  245. case reflect.Float32:
  246. v = v.Convert(reflect.TypeOf(float64(0)))
  247. default:
  248. return fmt.Errorf("cannot convert %s to FLOAT", v.Type())
  249. }
  250. C.sqlite3_result_double(ctx, C.double(v.Interface().(float64)))
  251. return nil
  252. }
  253. func callbackRetBlob(ctx *C.sqlite3_context, v reflect.Value) error {
  254. if v.Type().Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 {
  255. return fmt.Errorf("cannot convert %s to BLOB", v.Type())
  256. }
  257. i := v.Interface()
  258. if i == nil || len(i.([]byte)) == 0 {
  259. C.sqlite3_result_null(ctx)
  260. } else {
  261. bs := i.([]byte)
  262. C._sqlite3_result_blob(ctx, unsafe.Pointer(&bs[0]), C.int(len(bs)))
  263. }
  264. return nil
  265. }
  266. func callbackRetText(ctx *C.sqlite3_context, v reflect.Value) error {
  267. if v.Type().Kind() != reflect.String {
  268. return fmt.Errorf("cannot convert %s to TEXT", v.Type())
  269. }
  270. C._sqlite3_result_text(ctx, C.CString(v.Interface().(string)))
  271. return nil
  272. }
  273. func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
  274. switch typ.Kind() {
  275. case reflect.Slice:
  276. if typ.Elem().Kind() != reflect.Uint8 {
  277. return nil, errors.New("the only supported slice type is []byte")
  278. }
  279. return callbackRetBlob, nil
  280. case reflect.String:
  281. return callbackRetText, nil
  282. case reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
  283. return callbackRetInteger, nil
  284. case reflect.Float32, reflect.Float64:
  285. return callbackRetFloat, nil
  286. default:
  287. return nil, fmt.Errorf("don't know how to convert to %s", typ)
  288. }
  289. }
  290. func callbackError(ctx *C.sqlite3_context, err error) {
  291. cstr := C.CString(err.Error())
  292. defer C.free(unsafe.Pointer(cstr))
  293. C.sqlite3_result_error(ctx, cstr, -1)
  294. }
  295. // Test support code. Tests are not allowed to import "C", so we can't
  296. // declare any functions that use C.sqlite3_value.
  297. func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter {
  298. return func(*C.sqlite3_value) (reflect.Value, error) {
  299. return v, err
  300. }
  301. }