1
0

walker.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. package main
  2. import (
  3. "fmt"
  4. "go/ast"
  5. "go/parser"
  6. "go/token"
  7. "io/fs"
  8. "path/filepath"
  9. "strings"
  10. )
  11. type walkOverride struct {
  12. Field string
  13. Kind TypeKind
  14. }
  15. type packageRequest struct {
  16. Path string
  17. StructAllow map[string]bool
  18. AliasAllow map[string]bool
  19. Overrides map[string][]walkOverride
  20. }
  21. func walkPackages(requests []packageRequest) ([]Schema, []Alias, error) {
  22. fset := token.NewFileSet()
  23. var schemas []Schema
  24. var aliases []Alias
  25. for _, req := range requests {
  26. dir := req.Path
  27. pkgs, err := parser.ParseDir(fset, dir, func(fi fs.FileInfo) bool {
  28. return !strings.HasSuffix(fi.Name(), "_test.go")
  29. }, parser.ParseComments)
  30. if err != nil {
  31. return nil, nil, fmt.Errorf("parse %s: %w", dir, err)
  32. }
  33. for _, pkg := range pkgs {
  34. for _, file := range pkg.Files {
  35. for _, decl := range file.Decls {
  36. gen, ok := decl.(*ast.GenDecl)
  37. if !ok || gen.Tok != token.TYPE {
  38. continue
  39. }
  40. for _, spec := range gen.Specs {
  41. ts, ok := spec.(*ast.TypeSpec)
  42. if !ok {
  43. continue
  44. }
  45. if strct, ok := ts.Type.(*ast.StructType); ok {
  46. if req.StructAllow != nil && !req.StructAllow[ts.Name.Name] {
  47. continue
  48. }
  49. s := Schema{
  50. Name: ts.Name.Name,
  51. Package: pkg.Name,
  52. Doc: collectDoc(gen.Doc, ts.Doc),
  53. }
  54. overrides := req.Overrides[ts.Name.Name]
  55. for _, fld := range strct.Fields.List {
  56. for _, f := range buildFields(fld, overrides) {
  57. s.Fields = append(s.Fields, f)
  58. }
  59. }
  60. schemas = append(schemas, s)
  61. continue
  62. }
  63. if req.AliasAllow != nil && !req.AliasAllow[ts.Name.Name] {
  64. continue
  65. }
  66. aliases = append(aliases, Alias{
  67. Name: ts.Name.Name,
  68. Package: pkg.Name,
  69. Underlying: exprToType(ts.Type),
  70. })
  71. }
  72. }
  73. }
  74. }
  75. }
  76. return schemas, aliases, nil
  77. }
  78. func collectDoc(group ...*ast.CommentGroup) string {
  79. var b strings.Builder
  80. for _, g := range group {
  81. if g == nil {
  82. continue
  83. }
  84. for _, c := range g.List {
  85. line := strings.TrimPrefix(c.Text, "// ")
  86. line = strings.TrimPrefix(line, "//")
  87. b.WriteString(strings.TrimSpace(line))
  88. b.WriteByte('\n')
  89. }
  90. }
  91. return strings.TrimSpace(b.String())
  92. }
  93. func buildFields(fld *ast.Field, overrides []walkOverride) []Field {
  94. var fields []Field
  95. tag := ""
  96. if fld.Tag != nil {
  97. tag = fld.Tag.Value
  98. }
  99. jsonTag, validateTag, gormDash := parseStructTag(tag)
  100. if gormDash && jsonTag == "" {
  101. return nil
  102. }
  103. jsonName, omit, omitempty := parseJSONTag(jsonTag)
  104. if omit {
  105. return nil
  106. }
  107. validate := parseValidateTag(validateTag)
  108. doc := collectDoc(fld.Doc, fld.Comment)
  109. for _, n := range fld.Names {
  110. fname := jsonName
  111. if fname == "" {
  112. fname = lowerFirst(n.Name)
  113. }
  114. t := exprToType(fld.Type)
  115. for _, o := range overrides {
  116. if o.Field == n.Name || o.Field == jsonName {
  117. t = TypeRef{Kind: o.Kind}
  118. break
  119. }
  120. }
  121. fields = append(fields, Field{
  122. JSONName: fname,
  123. GoName: n.Name,
  124. Type: t,
  125. Optional: omitempty || isPointer(fld.Type),
  126. Validate: validate,
  127. Doc: doc,
  128. })
  129. }
  130. if len(fld.Names) == 0 {
  131. fname := jsonName
  132. if fname == "" {
  133. fname = lowerFirst(exprIdentName(fld.Type))
  134. }
  135. t := exprToType(fld.Type)
  136. for _, o := range overrides {
  137. if o.Field == exprIdentName(fld.Type) || o.Field == jsonName {
  138. t = TypeRef{Kind: o.Kind}
  139. break
  140. }
  141. }
  142. fields = append(fields, Field{
  143. JSONName: fname,
  144. GoName: exprIdentName(fld.Type),
  145. Type: t,
  146. Optional: omitempty || isPointer(fld.Type),
  147. Validate: validate,
  148. Doc: doc,
  149. })
  150. }
  151. return fields
  152. }
  153. func exprToType(expr ast.Expr) TypeRef {
  154. switch e := expr.(type) {
  155. case *ast.Ident:
  156. return identType(e.Name)
  157. case *ast.StarExpr:
  158. inner := exprToType(e.X)
  159. return TypeRef{Kind: KindRef, Name: "nullable", Inner: &inner}
  160. case *ast.ArrayType:
  161. elem := exprToType(e.Elt)
  162. return TypeRef{Kind: KindArray, Element: &elem}
  163. case *ast.MapType:
  164. k := exprToType(e.Key)
  165. v := exprToType(e.Value)
  166. return TypeRef{Kind: KindMap, Key: &k, Value: &v}
  167. case *ast.SelectorExpr:
  168. pkg := exprIdentName(e.X)
  169. name := e.Sel.Name
  170. if pkg == "json" && name == "RawMessage" {
  171. return TypeRef{Kind: KindAny}
  172. }
  173. if pkg == "time" && name == "Time" {
  174. return TypeRef{Kind: KindString, Name: "datetime"}
  175. }
  176. return TypeRef{Kind: KindRef, Name: name}
  177. case *ast.InterfaceType:
  178. return TypeRef{Kind: KindAny}
  179. default:
  180. return TypeRef{Kind: KindUnknown}
  181. }
  182. }
  183. func identType(name string) TypeRef {
  184. switch name {
  185. case "string":
  186. return TypeRef{Kind: KindString}
  187. case "bool":
  188. return TypeRef{Kind: KindBool}
  189. case "int", "int8", "int16", "int32", "int64",
  190. "uint", "uint8", "uint16", "uint32", "uint64":
  191. return TypeRef{Kind: KindInt}
  192. case "float32", "float64":
  193. return TypeRef{Kind: KindNumber}
  194. case "byte", "rune":
  195. return TypeRef{Kind: KindInt}
  196. case "any":
  197. return TypeRef{Kind: KindAny}
  198. default:
  199. return TypeRef{Kind: KindRef, Name: name}
  200. }
  201. }
  202. func isPointer(expr ast.Expr) bool {
  203. _, ok := expr.(*ast.StarExpr)
  204. return ok
  205. }
  206. func exprIdentName(expr ast.Expr) string {
  207. switch e := expr.(type) {
  208. case *ast.Ident:
  209. return e.Name
  210. case *ast.SelectorExpr:
  211. return e.Sel.Name
  212. case *ast.StarExpr:
  213. return exprIdentName(e.X)
  214. default:
  215. return ""
  216. }
  217. }
  218. func lowerFirst(s string) string {
  219. if s == "" {
  220. return s
  221. }
  222. return strings.ToLower(s[:1]) + s[1:]
  223. }
  224. func resolveRel(base, rel string) string {
  225. if filepath.IsAbs(rel) {
  226. return rel
  227. }
  228. return filepath.Clean(filepath.Join(base, rel))
  229. }