gen-accessors.go 8.4 KB


// Copyright 2017 The go-github AUTHORS. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// +build ignore

// gen-accessors generates accessor methods for structs with pointer fields.
//
// It is meant to be used by the go-github authors in conjunction with the
// go generate tool before sending a commit to GitHub.
  10. package main
  11. import (
  12. "bytes"
  13. "flag"
  14. "fmt"
  15. "go/ast"
  16. "go/format"
  17. "go/parser"
  18. "go/token"
  19. "io/ioutil"
  20. "log"
  21. "os"
  22. "sort"
  23. "strings"
  24. "text/template"
  25. )
  26. const (
  27. fileSuffix = "-accessors.go"
  28. )
  29. var (
  30. verbose = flag.Bool("v", false, "Print verbose log messages")
  31. sourceTmpl = template.Must(template.New("source").Parse(source))
  32. // blacklistStructMethod lists "struct.method" combos to skip.
  33. blacklistStructMethod = map[string]bool{
  34. "RepositoryContent.GetContent": true,
  35. "Client.GetBaseURL": true,
  36. "Client.GetUploadURL": true,
  37. "ErrorResponse.GetResponse": true,
  38. "RateLimitError.GetResponse": true,
  39. "AbuseRateLimitError.GetResponse": true,
  40. }
  41. // blacklistStruct lists structs to skip.
  42. blacklistStruct = map[string]bool{
  43. "Client": true,
  44. }
  45. )
  46. func logf(fmt string, args ...interface{}) {
  47. if *verbose {
  48. log.Printf(fmt, args...)
  49. }
  50. }
  51. func main() {
  52. flag.Parse()
  53. fset := token.NewFileSet()
  54. pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0)
  55. if err != nil {
  56. log.Fatal(err)
  57. return
  58. }
  59. for pkgName, pkg := range pkgs {
  60. t := &templateData{
  61. filename: pkgName + fileSuffix,
  62. Year: 2017,
  63. Package: pkgName,
  64. Imports: map[string]string{},
  65. }
  66. for filename, f := range pkg.Files {
  67. logf("Processing %v...", filename)
  68. if err := t.processAST(f); err != nil {
  69. log.Fatal(err)
  70. }
  71. }
  72. if err := t.dump(); err != nil {
  73. log.Fatal(err)
  74. }
  75. }
  76. logf("Done.")
  77. }
  78. func (t *templateData) processAST(f *ast.File) error {
  79. for _, decl := range f.Decls {
  80. gd, ok := decl.(*ast.GenDecl)
  81. if !ok {
  82. continue
  83. }
  84. for _, spec := range gd.Specs {
  85. ts, ok := spec.(*ast.TypeSpec)
  86. if !ok {
  87. continue
  88. }
  89. // Skip unexported identifiers.
  90. if !ts.Name.IsExported() {
  91. logf("Struct %v is unexported; skipping.", ts.Name)
  92. continue
  93. }
  94. // Check if the struct is blacklisted.
  95. if blacklistStruct[ts.Name.Name] {
  96. logf("Struct %v is blacklisted; skipping.", ts.Name)
  97. continue
  98. }
  99. st, ok := ts.Type.(*ast.StructType)
  100. if !ok {
  101. continue
  102. }
  103. for _, field := range st.Fields.List {
  104. se, ok := field.Type.(*ast.StarExpr)
  105. if len(field.Names) == 0 || !ok {
  106. continue
  107. }
  108. fieldName := field.Names[0]
  109. // Skip unexported identifiers.
  110. if !fieldName.IsExported() {
  111. logf("Field %v is unexported; skipping.", fieldName)
  112. continue
  113. }
  114. // Check if "struct.method" is blacklisted.
  115. if key := fmt.Sprintf("%v.Get%v", ts.Name, fieldName); blacklistStructMethod[key] {
  116. logf("Method %v is blacklisted; skipping.", key)
  117. continue
  118. }
  119. switch x := se.X.(type) {
  120. case *ast.ArrayType:
  121. t.addArrayType(x, ts.Name.String(), fieldName.String())
  122. case *ast.Ident:
  123. t.addIdent(x, ts.Name.String(), fieldName.String())
  124. case *ast.MapType:
  125. t.addMapType(x, ts.Name.String(), fieldName.String())
  126. case *ast.SelectorExpr:
  127. t.addSelectorExpr(x, ts.Name.String(), fieldName.String())
  128. default:
  129. logf("processAST: type %q, field %q, unknown %T: %+v", ts.Name, fieldName, x, x)
  130. }
  131. }
  132. }
  133. }
  134. return nil
  135. }
  136. func sourceFilter(fi os.FileInfo) bool {
  137. return !strings.HasSuffix(fi.Name(), "_test.go") && !strings.HasSuffix(fi.Name(), fileSuffix)
  138. }
  139. func (t *templateData) dump() error {
  140. if len(t.Getters) == 0 {
  141. logf("No getters for %v; skipping.", t.filename)
  142. return nil
  143. }
  144. // Sort getters by ReceiverType.FieldName.
  145. sort.Sort(byName(t.Getters))
  146. var buf bytes.Buffer
  147. if err := sourceTmpl.Execute(&buf, t); err != nil {
  148. return err
  149. }
  150. clean, err := format.Source(buf.Bytes())
  151. if err != nil {
  152. return err
  153. }
  154. logf("Writing %v...", t.filename)
  155. return ioutil.WriteFile(t.filename, clean, 0644)
  156. }
  157. func newGetter(receiverType, fieldName, fieldType, zeroValue string, namedStruct bool) *getter {
  158. return &getter{
  159. sortVal: strings.ToLower(receiverType) + "." + strings.ToLower(fieldName),
  160. ReceiverVar: strings.ToLower(receiverType[:1]),
  161. ReceiverType: receiverType,
  162. FieldName: fieldName,
  163. FieldType: fieldType,
  164. ZeroValue: zeroValue,
  165. NamedStruct: namedStruct,
  166. }
  167. }
  168. func (t *templateData) addArrayType(x *ast.ArrayType, receiverType, fieldName string) {
  169. var eltType string
  170. switch elt := x.Elt.(type) {
  171. case *ast.Ident:
  172. eltType = elt.String()
  173. default:
  174. logf("addArrayType: type %q, field %q: unknown elt type: %T %+v; skipping.", receiverType, fieldName, elt, elt)
  175. return
  176. }
  177. t.Getters = append(t.Getters, newGetter(receiverType, fieldName, "[]"+eltType, "nil", false))
  178. }
  179. func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) {
  180. var zeroValue string
  181. var namedStruct = false
  182. switch x.String() {
  183. case "int", "int64":
  184. zeroValue = "0"
  185. case "string":
  186. zeroValue = `""`
  187. case "bool":
  188. zeroValue = "false"
  189. case "Timestamp":
  190. zeroValue = "Timestamp{}"
  191. default:
  192. zeroValue = "nil"
  193. namedStruct = true
  194. }
  195. t.Getters = append(t.Getters, newGetter(receiverType, fieldName, x.String(), zeroValue, namedStruct))
  196. }
  197. func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string) {
  198. var keyType string
  199. switch key := x.Key.(type) {
  200. case *ast.Ident:
  201. keyType = key.String()
  202. default:
  203. logf("addMapType: type %q, field %q: unknown key type: %T %+v; skipping.", receiverType, fieldName, key, key)
  204. return
  205. }
  206. var valueType string
  207. switch value := x.Value.(type) {
  208. case *ast.Ident:
  209. valueType = value.String()
  210. default:
  211. logf("addMapType: type %q, field %q: unknown value type: %T %+v; skipping.", receiverType, fieldName, value, value)
  212. return
  213. }
  214. fieldType := fmt.Sprintf("map[%v]%v", keyType, valueType)
  215. zeroValue := fmt.Sprintf("map[%v]%v{}", keyType, valueType)
  216. t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false))
  217. }
  218. func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldName string) {
  219. if strings.ToLower(fieldName[:1]) == fieldName[:1] { // Non-exported field.
  220. return
  221. }
  222. var xX string
  223. if xx, ok := x.X.(*ast.Ident); ok {
  224. xX = xx.String()
  225. }
  226. switch xX {
  227. case "time", "json":
  228. if xX == "json" {
  229. t.Imports["encoding/json"] = "encoding/json"
  230. } else {
  231. t.Imports[xX] = xX
  232. }
  233. fieldType := fmt.Sprintf("%v.%v", xX, x.Sel.Name)
  234. zeroValue := fmt.Sprintf("%v.%v{}", xX, x.Sel.Name)
  235. if xX == "time" && x.Sel.Name == "Duration" {
  236. zeroValue = "0"
  237. }
  238. t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false))
  239. default:
  240. logf("addSelectorExpr: xX %q, type %q, field %q: unknown x=%+v; skipping.", xX, receiverType, fieldName, x)
  241. }
  242. }
  243. type templateData struct {
  244. filename string
  245. Year int
  246. Package string
  247. Imports map[string]string
  248. Getters []*getter
  249. }
  250. type getter struct {
  251. sortVal string // Lower-case version of "ReceiverType.FieldName".
  252. ReceiverVar string // The one-letter variable name to match the ReceiverType.
  253. ReceiverType string
  254. FieldName string
  255. FieldType string
  256. ZeroValue string
  257. NamedStruct bool // Getter for named struct.
  258. }
  259. type byName []*getter
  260. func (b byName) Len() int { return len(b) }
  261. func (b byName) Less(i, j int) bool { return b[i].sortVal < b[j].sortVal }
  262. func (b byName) Swap(i, j int) { b[i], b[j] = b[j], b[i] }
  263. const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
  264. //
  265. // Use of this source code is governed by a BSD-style
  266. // license that can be found in the LICENSE file.
  267. // Code generated by gen-accessors; DO NOT EDIT.
  268. package {{.Package}}
  269. {{with .Imports}}
  270. import (
  271. {{- range . -}}
  272. "{{.}}"
  273. {{end -}}
  274. )
  275. {{end}}
  276. {{range .Getters}}
  277. {{if .NamedStruct}}
  278. // Get{{.FieldName}} returns the {{.FieldName}} field.
  279. func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() *{{.FieldType}} {
  280. if {{.ReceiverVar}} == nil {
  281. return {{.ZeroValue}}
  282. }
  283. return {{.ReceiverVar}}.{{.FieldName}}
  284. }
  285. {{else}}
  286. // Get{{.FieldName}} returns the {{.FieldName}} field if it's non-nil, zero value otherwise.
  287. func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} {
  288. if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil {
  289. return {{.ZeroValue}}
  290. }
  291. return *{{.ReceiverVar}}.{{.FieldName}}
  292. }
  293. {{end}}
  294. {{end}}
  295. `