walker.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. package main
import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"io/fs"
	"path/filepath"
	"sort"
	"strings"
)
  11. type walkOverride struct {
  12. Field string
  13. Kind TypeKind
  14. }
  15. type packageRequest struct {
  16. Path string
  17. StructAllow map[string]bool
  18. AliasAllow map[string]bool
  19. Overrides map[string][]walkOverride
  20. }
  21. func walkPackages(requests []packageRequest) ([]Schema, []Alias, error) {
  22. fset := token.NewFileSet()
  23. var schemas []Schema
  24. var aliases []Alias
  25. for _, req := range requests {
  26. dir := req.Path
  27. pkgs, err := parser.ParseDir(fset, dir, func(fi fs.FileInfo) bool {
  28. return !strings.HasSuffix(fi.Name(), "_test.go")
  29. }, parser.ParseComments)
  30. if err != nil {
  31. return nil, nil, fmt.Errorf("parse %s: %w", dir, err)
  32. }
  33. for _, pkg := range pkgs {
  34. for _, file := range pkg.Files {
  35. for _, decl := range file.Decls {
  36. gen, ok := decl.(*ast.GenDecl)
  37. if !ok || gen.Tok != token.TYPE {
  38. continue
  39. }
  40. for _, spec := range gen.Specs {
  41. ts, ok := spec.(*ast.TypeSpec)
  42. if !ok {
  43. continue
  44. }
  45. if strct, ok := ts.Type.(*ast.StructType); ok {
  46. if req.StructAllow != nil && !req.StructAllow[ts.Name.Name] {
  47. continue
  48. }
  49. s := Schema{
  50. Name: ts.Name.Name,
  51. Package: pkg.Name,
  52. Doc: collectDoc(gen.Doc, ts.Doc),
  53. }
  54. overrides := req.Overrides[ts.Name.Name]
  55. for _, fld := range strct.Fields.List {
  56. s.Fields = append(s.Fields, buildFields(fld, overrides)...)
  57. }
  58. schemas = append(schemas, s)
  59. continue
  60. }
  61. if req.AliasAllow != nil && !req.AliasAllow[ts.Name.Name] {
  62. continue
  63. }
  64. aliases = append(aliases, Alias{
  65. Name: ts.Name.Name,
  66. Package: pkg.Name,
  67. Underlying: exprToType(ts.Type),
  68. })
  69. }
  70. }
  71. }
  72. }
  73. }
  74. return schemas, aliases, nil
  75. }
  76. func collectDoc(group ...*ast.CommentGroup) string {
  77. var b strings.Builder
  78. for _, g := range group {
  79. if g == nil {
  80. continue
  81. }
  82. for _, c := range g.List {
  83. line := strings.TrimPrefix(c.Text, "// ")
  84. line = strings.TrimPrefix(line, "//")
  85. b.WriteString(strings.TrimSpace(line))
  86. b.WriteByte('\n')
  87. }
  88. }
  89. return strings.TrimSpace(b.String())
  90. }
  91. func buildFields(fld *ast.Field, overrides []walkOverride) []Field {
  92. var fields []Field
  93. tag := ""
  94. if fld.Tag != nil {
  95. tag = fld.Tag.Value
  96. }
  97. jsonTag, validateTag, exampleTag, gormDash := parseStructTag(tag)
  98. if gormDash && jsonTag == "" {
  99. return nil
  100. }
  101. jsonName, omit, omitempty := parseJSONTag(jsonTag)
  102. if omit {
  103. return nil
  104. }
  105. validate := parseValidateTag(validateTag)
  106. doc := collectDoc(fld.Doc, fld.Comment)
  107. for _, n := range fld.Names {
  108. fname := jsonName
  109. if fname == "" {
  110. fname = lowerFirst(n.Name)
  111. }
  112. t := exprToType(fld.Type)
  113. for _, o := range overrides {
  114. if o.Field == n.Name || o.Field == jsonName {
  115. t = TypeRef{Kind: o.Kind}
  116. break
  117. }
  118. }
  119. fields = append(fields, Field{
  120. JSONName: fname,
  121. GoName: n.Name,
  122. Type: t,
  123. Optional: omitempty || isPointer(fld.Type),
  124. Validate: validate,
  125. Doc: doc,
  126. Example: exampleTag,
  127. })
  128. }
  129. if len(fld.Names) == 0 {
  130. fname := jsonName
  131. if fname == "" {
  132. fname = lowerFirst(exprIdentName(fld.Type))
  133. }
  134. t := exprToType(fld.Type)
  135. for _, o := range overrides {
  136. if o.Field == exprIdentName(fld.Type) || o.Field == jsonName {
  137. t = TypeRef{Kind: o.Kind}
  138. break
  139. }
  140. }
  141. fields = append(fields, Field{
  142. JSONName: fname,
  143. GoName: exprIdentName(fld.Type),
  144. Type: t,
  145. Optional: omitempty || isPointer(fld.Type),
  146. Validate: validate,
  147. Doc: doc,
  148. Example: exampleTag,
  149. })
  150. }
  151. return fields
  152. }
  153. func exprToType(expr ast.Expr) TypeRef {
  154. switch e := expr.(type) {
  155. case *ast.Ident:
  156. return identType(e.Name)
  157. case *ast.StarExpr:
  158. inner := exprToType(e.X)
  159. return TypeRef{Kind: KindRef, Name: "nullable", Inner: &inner}
  160. case *ast.ArrayType:
  161. elem := exprToType(e.Elt)
  162. return TypeRef{Kind: KindArray, Element: &elem}
  163. case *ast.MapType:
  164. k := exprToType(e.Key)
  165. v := exprToType(e.Value)
  166. return TypeRef{Kind: KindMap, Key: &k, Value: &v}
  167. case *ast.SelectorExpr:
  168. pkg := exprIdentName(e.X)
  169. name := e.Sel.Name
  170. if pkg == "json" && name == "RawMessage" {
  171. return TypeRef{Kind: KindAny}
  172. }
  173. if pkg == "time" && name == "Time" {
  174. return TypeRef{Kind: KindString, Name: "datetime"}
  175. }
  176. return TypeRef{Kind: KindRef, Name: name}
  177. case *ast.InterfaceType:
  178. return TypeRef{Kind: KindAny}
  179. default:
  180. return TypeRef{Kind: KindUnknown}
  181. }
  182. }
  183. func identType(name string) TypeRef {
  184. switch name {
  185. case "string":
  186. return TypeRef{Kind: KindString}
  187. case "bool":
  188. return TypeRef{Kind: KindBool}
  189. case "int", "int8", "int16", "int32", "int64",
  190. "uint", "uint8", "uint16", "uint32", "uint64":
  191. return TypeRef{Kind: KindInt}
  192. case "float32", "float64":
  193. return TypeRef{Kind: KindNumber}
  194. case "byte", "rune":
  195. return TypeRef{Kind: KindInt}
  196. case "any":
  197. return TypeRef{Kind: KindAny}
  198. default:
  199. return TypeRef{Kind: KindRef, Name: name}
  200. }
  201. }
  202. func isPointer(expr ast.Expr) bool {
  203. _, ok := expr.(*ast.StarExpr)
  204. return ok
  205. }
  206. func exprIdentName(expr ast.Expr) string {
  207. switch e := expr.(type) {
  208. case *ast.Ident:
  209. return e.Name
  210. case *ast.SelectorExpr:
  211. return e.Sel.Name
  212. case *ast.StarExpr:
  213. return exprIdentName(e.X)
  214. default:
  215. return ""
  216. }
  217. }
  218. func lowerFirst(s string) string {
  219. if s == "" {
  220. return s
  221. }
  222. return strings.ToLower(s[:1]) + s[1:]
  223. }
  224. func resolveRel(base, rel string) string {
  225. if filepath.IsAbs(rel) {
  226. return rel
  227. }
  228. return filepath.Clean(filepath.Join(base, rel))
  229. }