123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332 |
- // Copyright 2017 The go-github AUTHORS. All rights reserved.
- //
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
// +build ignore

// gen-accessors generates accessor methods for structs with pointer fields.
//
// It is meant to be run by the go-github authors, via the go generate tool,
// before sending a commit to GitHub.
- package main
- import (
- "bytes"
- "flag"
- "fmt"
- "go/ast"
- "go/format"
- "go/parser"
- "go/token"
- "io/ioutil"
- "log"
- "os"
- "sort"
- "strings"
- "text/template"
- )
// fileSuffix is appended to a package name to form the name of the generated
// file; sourceFilter also uses it to keep previously generated files out of
// the parse.
const fileSuffix = "-accessors.go"
- var (
- verbose = flag.Bool("v", false, "Print verbose log messages")
- sourceTmpl = template.Must(template.New("source").Parse(source))
- // blacklistStructMethod lists "struct.method" combos to skip.
- blacklistStructMethod = map[string]bool{
- "RepositoryContent.GetContent": true,
- "Client.GetBaseURL": true,
- "Client.GetUploadURL": true,
- "ErrorResponse.GetResponse": true,
- "RateLimitError.GetResponse": true,
- "AbuseRateLimitError.GetResponse": true,
- }
- // blacklistStruct lists structs to skip.
- blacklistStruct = map[string]bool{
- "Client": true,
- }
- )
- func logf(fmt string, args ...interface{}) {
- if *verbose {
- log.Printf(fmt, args...)
- }
- }
- func main() {
- flag.Parse()
- fset := token.NewFileSet()
- pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0)
- if err != nil {
- log.Fatal(err)
- return
- }
- for pkgName, pkg := range pkgs {
- t := &templateData{
- filename: pkgName + fileSuffix,
- Year: 2017,
- Package: pkgName,
- Imports: map[string]string{},
- }
- for filename, f := range pkg.Files {
- logf("Processing %v...", filename)
- if err := t.processAST(f); err != nil {
- log.Fatal(err)
- }
- }
- if err := t.dump(); err != nil {
- log.Fatal(err)
- }
- }
- logf("Done.")
- }
- func (t *templateData) processAST(f *ast.File) error {
- for _, decl := range f.Decls {
- gd, ok := decl.(*ast.GenDecl)
- if !ok {
- continue
- }
- for _, spec := range gd.Specs {
- ts, ok := spec.(*ast.TypeSpec)
- if !ok {
- continue
- }
- // Skip unexported identifiers.
- if !ts.Name.IsExported() {
- logf("Struct %v is unexported; skipping.", ts.Name)
- continue
- }
- // Check if the struct is blacklisted.
- if blacklistStruct[ts.Name.Name] {
- logf("Struct %v is blacklisted; skipping.", ts.Name)
- continue
- }
- st, ok := ts.Type.(*ast.StructType)
- if !ok {
- continue
- }
- for _, field := range st.Fields.List {
- se, ok := field.Type.(*ast.StarExpr)
- if len(field.Names) == 0 || !ok {
- continue
- }
- fieldName := field.Names[0]
- // Skip unexported identifiers.
- if !fieldName.IsExported() {
- logf("Field %v is unexported; skipping.", fieldName)
- continue
- }
- // Check if "struct.method" is blacklisted.
- if key := fmt.Sprintf("%v.Get%v", ts.Name, fieldName); blacklistStructMethod[key] {
- logf("Method %v is blacklisted; skipping.", key)
- continue
- }
- switch x := se.X.(type) {
- case *ast.ArrayType:
- t.addArrayType(x, ts.Name.String(), fieldName.String())
- case *ast.Ident:
- t.addIdent(x, ts.Name.String(), fieldName.String())
- case *ast.MapType:
- t.addMapType(x, ts.Name.String(), fieldName.String())
- case *ast.SelectorExpr:
- t.addSelectorExpr(x, ts.Name.String(), fieldName.String())
- default:
- logf("processAST: type %q, field %q, unknown %T: %+v", ts.Name, fieldName, x, x)
- }
- }
- }
- }
- return nil
- }
- func sourceFilter(fi os.FileInfo) bool {
- return !strings.HasSuffix(fi.Name(), "_test.go") && !strings.HasSuffix(fi.Name(), fileSuffix)
- }
- func (t *templateData) dump() error {
- if len(t.Getters) == 0 {
- logf("No getters for %v; skipping.", t.filename)
- return nil
- }
- // Sort getters by ReceiverType.FieldName.
- sort.Sort(byName(t.Getters))
- var buf bytes.Buffer
- if err := sourceTmpl.Execute(&buf, t); err != nil {
- return err
- }
- clean, err := format.Source(buf.Bytes())
- if err != nil {
- return err
- }
- logf("Writing %v...", t.filename)
- return ioutil.WriteFile(t.filename, clean, 0644)
- }
- func newGetter(receiverType, fieldName, fieldType, zeroValue string, namedStruct bool) *getter {
- return &getter{
- sortVal: strings.ToLower(receiverType) + "." + strings.ToLower(fieldName),
- ReceiverVar: strings.ToLower(receiverType[:1]),
- ReceiverType: receiverType,
- FieldName: fieldName,
- FieldType: fieldType,
- ZeroValue: zeroValue,
- NamedStruct: namedStruct,
- }
- }
- func (t *templateData) addArrayType(x *ast.ArrayType, receiverType, fieldName string) {
- var eltType string
- switch elt := x.Elt.(type) {
- case *ast.Ident:
- eltType = elt.String()
- default:
- logf("addArrayType: type %q, field %q: unknown elt type: %T %+v; skipping.", receiverType, fieldName, elt, elt)
- return
- }
- t.Getters = append(t.Getters, newGetter(receiverType, fieldName, "[]"+eltType, "nil", false))
- }
- func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) {
- var zeroValue string
- var namedStruct = false
- switch x.String() {
- case "int", "int64":
- zeroValue = "0"
- case "string":
- zeroValue = `""`
- case "bool":
- zeroValue = "false"
- case "Timestamp":
- zeroValue = "Timestamp{}"
- default:
- zeroValue = "nil"
- namedStruct = true
- }
- t.Getters = append(t.Getters, newGetter(receiverType, fieldName, x.String(), zeroValue, namedStruct))
- }
- func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string) {
- var keyType string
- switch key := x.Key.(type) {
- case *ast.Ident:
- keyType = key.String()
- default:
- logf("addMapType: type %q, field %q: unknown key type: %T %+v; skipping.", receiverType, fieldName, key, key)
- return
- }
- var valueType string
- switch value := x.Value.(type) {
- case *ast.Ident:
- valueType = value.String()
- default:
- logf("addMapType: type %q, field %q: unknown value type: %T %+v; skipping.", receiverType, fieldName, value, value)
- return
- }
- fieldType := fmt.Sprintf("map[%v]%v", keyType, valueType)
- zeroValue := fmt.Sprintf("map[%v]%v{}", keyType, valueType)
- t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false))
- }
- func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldName string) {
- if strings.ToLower(fieldName[:1]) == fieldName[:1] { // Non-exported field.
- return
- }
- var xX string
- if xx, ok := x.X.(*ast.Ident); ok {
- xX = xx.String()
- }
- switch xX {
- case "time", "json":
- if xX == "json" {
- t.Imports["encoding/json"] = "encoding/json"
- } else {
- t.Imports[xX] = xX
- }
- fieldType := fmt.Sprintf("%v.%v", xX, x.Sel.Name)
- zeroValue := fmt.Sprintf("%v.%v{}", xX, x.Sel.Name)
- if xX == "time" && x.Sel.Name == "Duration" {
- zeroValue = "0"
- }
- t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false))
- default:
- logf("addSelectorExpr: xX %q, type %q, field %q: unknown x=%+v; skipping.", xX, receiverType, fieldName, x)
- }
- }
// templateData is the value sourceTmpl is executed with: everything needed to
// render one package's accessor file.
type templateData struct {
	filename string            // Output path ("<package>" + fileSuffix); unexported, so invisible to the template.
	Year     int               // Copyright year for the generated header.
	Package  string            // Package clause of the generated file.
	Imports  map[string]string // Import path -> import path, ranged over by the template.
	Getters  []*getter         // One entry per generated Get method.
}

// getter describes a single generated Get<FieldName> method.
type getter struct {
	sortVal      string // Lower-cased "ReceiverType.FieldName"; the byName sort key.
	ReceiverVar  string // One-letter receiver name derived from ReceiverType.
	ReceiverType string
	FieldName    string
	FieldType    string // Go source for the result type (the template adds "*" when NamedStruct).
	ZeroValue    string // Go source returned when the receiver (or field) is nil.
	NamedStruct  bool   // When true, the getter returns the pointer field itself.
}

// byName implements sort.Interface, ordering getters by sortVal.
type byName []*getter

func (s byName) Len() int {
	return len(s)
}

func (s byName) Less(i, j int) bool {
	return s[i].sortVal < s[j].sortVal
}

func (s byName) Swap(i, j int) {
	s[i], s[j] = s[j], s[i]
}
// source is the text/template for a generated <package>-accessors.go file.
// dump executes it with a *templateData and passes the result through
// go/format, so only the Go syntax it produces matters, not its layout.
//
// The blank lines after the license and after the "Code generated" marker
// keep both comments apart from the package clause. A comment group directly
// above "package" becomes the package's doc comment.
const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Code generated by gen-accessors; DO NOT EDIT.

package {{.Package}}
{{with .Imports}}
import (
	{{- range . -}}
	"{{.}}"
	{{end -}}
)
{{end}}
{{range .Getters}}
{{if .NamedStruct}}
// Get{{.FieldName}} returns the {{.FieldName}} field.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() *{{.FieldType}} {
	if {{.ReceiverVar}} == nil {
		return {{.ZeroValue}}
	}
	return {{.ReceiverVar}}.{{.FieldName}}
}
{{else}}
// Get{{.FieldName}} returns the {{.FieldName}} field if it's non-nil, zero value otherwise.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} {
	if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil {
		return {{.ZeroValue}}
	}
	return *{{.ReceiverVar}}.{{.FieldName}}
}
{{end}}
{{end}}
`
|