main.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. package main
  2. import (
  3. "log"
  4. "os"
  5. "strings"
  6. "github.com/olekukonko/tablewriter"
  7. "github.com/pkg/errors"
  8. "gopkg.in/DATA-DOG/go-sqlmock.v2"
  9. "gorm.io/driver/mysql"
  10. "gorm.io/driver/postgres"
  11. "gorm.io/driver/sqlite"
  12. "gorm.io/gorm"
  13. "gorm.io/gorm/clause"
  14. "gorm.io/gorm/schema"
  15. "gogs.io/gogs/internal/db"
  16. )
  17. //go:generate go run main.go ../../../docs/dev/database_schema.md
  18. func main() {
  19. w, err := os.Create(os.Args[1])
  20. if err != nil {
  21. log.Fatalf("Failed to create file: %v", err)
  22. }
  23. defer func() { _ = w.Close() }()
  24. conn, _, err := sqlmock.New()
  25. if err != nil {
  26. log.Fatalf("Failed to get mock connection: %v", err)
  27. }
  28. defer func() { _ = conn.Close() }()
  29. dialectors := []gorm.Dialector{
  30. postgres.New(postgres.Config{
  31. Conn: conn,
  32. }),
  33. mysql.New(mysql.Config{
  34. Conn: conn,
  35. SkipInitializeWithVersion: true,
  36. }),
  37. sqlite.Open(""),
  38. }
  39. collected := make([][]*tableInfo, 0, len(dialectors))
  40. for i, dialector := range dialectors {
  41. tableInfos, err := generate(dialector)
  42. if err != nil {
  43. log.Fatalf("Failed to get table info of %d: %v", i, err)
  44. }
  45. collected = append(collected, tableInfos)
  46. }
  47. for i, ti := range collected[0] {
  48. _, _ = w.WriteString(`# Table "` + ti.Name + `"`)
  49. _, _ = w.WriteString("\n\n")
  50. _, _ = w.WriteString("```\n")
  51. table := tablewriter.NewWriter(w)
  52. table.SetHeader([]string{"Field", "Column", "PostgreSQL", "MySQL", "SQLite3"})
  53. table.SetBorder(false)
  54. for j, f := range ti.Fields {
  55. table.Append([]string{
  56. f.Name, f.Column,
  57. strings.ToUpper(f.Type), // PostgreSQL
  58. strings.ToUpper(collected[1][i].Fields[j].Type), // MySQL
  59. strings.ToUpper(collected[2][i].Fields[j].Type), // SQLite3
  60. })
  61. }
  62. table.Render()
  63. _, _ = w.WriteString("\n")
  64. _, _ = w.WriteString("Primary keys: ")
  65. _, _ = w.WriteString(strings.Join(ti.PrimaryKeys, ", "))
  66. _, _ = w.WriteString("\n")
  67. _, _ = w.WriteString("```\n\n")
  68. }
  69. }
  // tableField is one documented column of a table: the Go struct field, the
// database column it maps to, and its full column type for one dialect.
type tableField struct {
	// Name is the Go struct field name.
	Name string
	// Column is the database column name.
	Column string
	// Type is the full column data type SQL produced by a dialect.
	Type string
}
  75. type tableInfo struct {
  76. Name string
  77. Fields []*tableField
  78. PrimaryKeys []string
  79. }
  80. // This function is derived from gorm.io/gorm/migrator/migrator.go:Migrator.CreateTable.
  81. func generate(dialector gorm.Dialector) ([]*tableInfo, error) {
  82. conn, err := gorm.Open(dialector, &gorm.Config{
  83. NamingStrategy: schema.NamingStrategy{
  84. SingularTable: true,
  85. },
  86. DryRun: true,
  87. DisableAutomaticPing: true,
  88. })
  89. if err != nil {
  90. return nil, errors.Wrap(err, "open database")
  91. }
  92. m := conn.Migrator().(interface {
  93. RunWithValue(value interface{}, fc func(*gorm.Statement) error) error
  94. FullDataTypeOf(*schema.Field) clause.Expr
  95. })
  96. tableInfos := make([]*tableInfo, 0, len(db.Tables))
  97. for _, table := range db.Tables {
  98. err = m.RunWithValue(table, func(stmt *gorm.Statement) error {
  99. fields := make([]*tableField, 0, len(stmt.Schema.DBNames))
  100. for _, field := range stmt.Schema.Fields {
  101. if field.DBName == "" {
  102. continue
  103. }
  104. fields = append(fields, &tableField{
  105. Name: field.Name,
  106. Column: field.DBName,
  107. Type: m.FullDataTypeOf(field).SQL,
  108. })
  109. }
  110. primaryKeys := make([]string, 0, len(stmt.Schema.PrimaryFields))
  111. if len(stmt.Schema.PrimaryFields) > 0 {
  112. for _, field := range stmt.Schema.PrimaryFields {
  113. primaryKeys = append(primaryKeys, field.DBName)
  114. }
  115. }
  116. tableInfos = append(tableInfos, &tableInfo{
  117. Name: stmt.Table,
  118. Fields: fields,
  119. PrimaryKeys: primaryKeys,
  120. })
  121. return nil
  122. })
  123. if err != nil {
  124. return nil, errors.Wrap(err, "gather table information")
  125. }
  126. }
  127. return tableInfos, nil
  128. }