emit_zod.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. package main
  2. import (
  3. "fmt"
  4. "io"
  5. "sort"
  6. "strings"
  7. )
  8. func emitZod(w io.Writer, schemas []Schema, aliases []Alias) error {
  9. if _, err := fmt.Fprintln(w, zodHeader); err != nil {
  10. return err
  11. }
  12. for _, a := range sortAliases(aliases) {
  13. if _, err := fmt.Fprintf(w, "export const %sSchema = %s;\n", a.Name, zodTypeExpr(a.Underlying)); err != nil {
  14. return err
  15. }
  16. if _, err := fmt.Fprintf(w, "export type %s = z.infer<typeof %sSchema>;\n\n", a.Name, a.Name); err != nil {
  17. return err
  18. }
  19. }
  20. for _, s := range sortSchemas(schemas) {
  21. if _, err := fmt.Fprintf(w, "export const %sSchema = z.object({\n", s.Name); err != nil {
  22. return err
  23. }
  24. fields := append([]Field(nil), s.Fields...)
  25. sort.SliceStable(fields, func(i, j int) bool { return fields[i].JSONName < fields[j].JSONName })
  26. for _, f := range fields {
  27. line := fmt.Sprintf(" %s: %s,\n", quoteIfNeeded(f.JSONName), zodExpr(f))
  28. if _, err := fmt.Fprint(w, line); err != nil {
  29. return err
  30. }
  31. }
  32. if _, err := fmt.Fprintln(w, "});"); err != nil {
  33. return err
  34. }
  35. if _, err := fmt.Fprintf(w, "export type %s = z.infer<typeof %sSchema>;\n\n", s.Name, s.Name); err != nil {
  36. return err
  37. }
  38. }
  39. return nil
  40. }
  41. func zodExpr(f Field) string {
  42. expr := zodTypeExpr(f.Type)
  43. expr = applyZodValidations(expr, f.Type, f.Validate)
  44. if f.Optional {
  45. expr += ".optional()"
  46. }
  47. return expr
  48. }
  49. func zodTypeExpr(t TypeRef) string {
  50. switch t.Kind {
  51. case KindString:
  52. return "z.string()"
  53. case KindBool:
  54. return "z.boolean()"
  55. case KindInt:
  56. return "z.number().int()"
  57. case KindNumber:
  58. return "z.number()"
  59. case KindAny, KindUnknown:
  60. return "z.unknown()"
  61. case KindRaw:
  62. return "z.unknown()"
  63. case KindArray:
  64. return "z.array(" + zodTypeExpr(*t.Element) + ")"
  65. case KindMap:
  66. return "z.record(" + zodTypeExpr(*t.Key) + ", " + zodTypeExpr(*t.Value) + ")"
  67. case KindRef:
  68. if t.Name == "nullable" {
  69. return zodTypeExpr(*t.Inner) + ".nullable()"
  70. }
  71. return "z.lazy(() => " + t.Name + "Schema)"
  72. }
  73. return "z.unknown()"
  74. }
  75. func applyZodValidations(expr string, t TypeRef, rules []ValidateRule) string {
  76. for _, r := range rules {
  77. switch r.Name {
  78. case "required":
  79. continue
  80. case "omitempty":
  81. continue
  82. case "gte":
  83. if t.Kind == KindInt || t.Kind == KindNumber {
  84. expr += fmt.Sprintf(".min(%s)", r.Param)
  85. }
  86. case "lte":
  87. if t.Kind == KindInt || t.Kind == KindNumber {
  88. expr += fmt.Sprintf(".max(%s)", r.Param)
  89. }
  90. case "gt":
  91. if t.Kind == KindInt || t.Kind == KindNumber {
  92. expr += fmt.Sprintf(".gt(%s)", r.Param)
  93. }
  94. case "lt":
  95. if t.Kind == KindInt || t.Kind == KindNumber {
  96. expr += fmt.Sprintf(".lt(%s)", r.Param)
  97. }
  98. case "min":
  99. if t.Kind == KindString {
  100. expr += fmt.Sprintf(".min(%s)", r.Param)
  101. } else if t.Kind == KindInt || t.Kind == KindNumber {
  102. expr += fmt.Sprintf(".min(%s)", r.Param)
  103. }
  104. case "max":
  105. if t.Kind == KindString {
  106. expr += fmt.Sprintf(".max(%s)", r.Param)
  107. } else if t.Kind == KindInt || t.Kind == KindNumber {
  108. expr += fmt.Sprintf(".max(%s)", r.Param)
  109. }
  110. case "url":
  111. expr += ".url()"
  112. case "email":
  113. expr += ".email()"
  114. case "oneof":
  115. values := strings.Fields(r.Param)
  116. quoted := make([]string, 0, len(values))
  117. for _, v := range values {
  118. quoted = append(quoted, fmt.Sprintf("'%s'", v))
  119. }
  120. expr = fmt.Sprintf("z.enum([%s])", strings.Join(quoted, ", "))
  121. }
  122. }
  123. return expr
  124. }
  125. func quoteIfNeeded(name string) string {
  126. if name == "" {
  127. return "''"
  128. }
  129. for i, r := range name {
  130. if r >= 'a' && r <= 'z' {
  131. continue
  132. }
  133. if r >= 'A' && r <= 'Z' {
  134. continue
  135. }
  136. if r == '_' || r == '$' {
  137. continue
  138. }
  139. if i > 0 && r >= '0' && r <= '9' {
  140. continue
  141. }
  142. return "'" + name + "'"
  143. }
  144. return name
  145. }
  146. const zodHeader = `// Code generated by tools/openapigen. DO NOT EDIT.
  147. import { z } from 'zod';`