.gitignore added
[dotfiles/.git] / .config / coc / extensions / coc-go-data / tools / pkg / mod / golang.org / x / sys@v0.0.0-20210124154548-22da62e12c0c / windows / mkwinsyscall / mkwinsyscall.go
1 // Copyright 2013 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 /*
6 mkwinsyscall generates windows system call bodies
7
8 It parses all files specified on command line containing function
9 prototypes (like syscall_windows.go) and prints system call bodies
10 to standard output.
11
12 The prototypes are marked by lines beginning with "//sys" and read
13 like func declarations if //sys is replaced by func, but:
14
15 * The parameter lists must give a name for each argument. This
16   includes return parameters.
17
18 * The parameter lists must give a type for each argument:
19   the (x, y, z int) shorthand is not allowed.
20
21 * If the return parameter is an error number, it must be named err.
22
23 * If go func name needs to be different from its winapi dll name,
24   the winapi name could be specified at the end, after "=" sign, like
25   //sys LoadLibrary(libname string) (handle uint32, err error) = LoadLibraryA
26
27 * Each function that returns err needs to supply a condition, that
28   return value of winapi will be tested against to detect failure.
29   This would set err to windows "last-error", otherwise it will be nil.
30   The value can be provided at end of //sys declaration, like
31   //sys LoadLibrary(libname string) (handle uint32, err error) [failretval==-1] = LoadLibraryA
32   and is [failretval==0] by default.
33
34 * If the function name ends in a "?", then the function not existing is non-
35   fatal, and an error will be returned instead of panicking.
36
37 Usage:
38         mkwinsyscall [flags] [path ...]
39
40 The flags are:
41         -output
42                 Specify output file name (outputs to console if blank).
43         -trace
44                 Generate print statement after every syscall.
45 */
46 package main
47
48 import (
49         "bufio"
50         "bytes"
51         "errors"
52         "flag"
53         "fmt"
54         "go/format"
55         "go/parser"
56         "go/token"
57         "io"
58         "io/ioutil"
59         "log"
60         "os"
61         "path/filepath"
62         "runtime"
63         "sort"
64         "strconv"
65         "strings"
66         "text/template"
67 )
68
69 var (
70         filename       = flag.String("output", "", "output file name (standard output if omitted)")
71         printTraceFlag = flag.Bool("trace", false, "generate print statement after every syscall")
72         systemDLL      = flag.Bool("systemdll", true, "whether all DLLs should be loaded from the Windows system directory")
73 )
74
75 func trim(s string) string {
76         return strings.Trim(s, " \t")
77 }
78
79 var packageName string
80
81 func packagename() string {
82         return packageName
83 }
84
85 func syscalldot() string {
86         if packageName == "syscall" {
87                 return ""
88         }
89         return "syscall."
90 }
91
92 // Param is function parameter
93 type Param struct {
94         Name      string
95         Type      string
96         fn        *Fn
97         tmpVarIdx int
98 }
99
100 // tmpVar returns temp variable name that will be used to represent p during syscall.
101 func (p *Param) tmpVar() string {
102         if p.tmpVarIdx < 0 {
103                 p.tmpVarIdx = p.fn.curTmpVarIdx
104                 p.fn.curTmpVarIdx++
105         }
106         return fmt.Sprintf("_p%d", p.tmpVarIdx)
107 }
108
109 // BoolTmpVarCode returns source code for bool temp variable.
110 func (p *Param) BoolTmpVarCode() string {
111         const code = `var %[1]s uint32
112         if %[2]s {
113                 %[1]s = 1
114         }`
115         return fmt.Sprintf(code, p.tmpVar(), p.Name)
116 }
117
118 // BoolPointerTmpVarCode returns source code for bool temp variable.
119 func (p *Param) BoolPointerTmpVarCode() string {
120         const code = `var %[1]s uint32
121         if *%[2]s {
122                 %[1]s = 1
123         }`
124         return fmt.Sprintf(code, p.tmpVar(), p.Name)
125 }
126
127 // SliceTmpVarCode returns source code for slice temp variable.
128 func (p *Param) SliceTmpVarCode() string {
129         const code = `var %s *%s
130         if len(%s) > 0 {
131                 %s = &%s[0]
132         }`
133         tmp := p.tmpVar()
134         return fmt.Sprintf(code, tmp, p.Type[2:], p.Name, tmp, p.Name)
135 }
136
137 // StringTmpVarCode returns source code for string temp variable.
138 func (p *Param) StringTmpVarCode() string {
139         errvar := p.fn.Rets.ErrorVarName()
140         if errvar == "" {
141                 errvar = "_"
142         }
143         tmp := p.tmpVar()
144         const code = `var %s %s
145         %s, %s = %s(%s)`
146         s := fmt.Sprintf(code, tmp, p.fn.StrconvType(), tmp, errvar, p.fn.StrconvFunc(), p.Name)
147         if errvar == "-" {
148                 return s
149         }
150         const morecode = `
151         if %s != nil {
152                 return
153         }`
154         return s + fmt.Sprintf(morecode, errvar)
155 }
156
157 // TmpVarCode returns source code for temp variable.
158 func (p *Param) TmpVarCode() string {
159         switch {
160         case p.Type == "bool":
161                 return p.BoolTmpVarCode()
162         case p.Type == "*bool":
163                 return p.BoolPointerTmpVarCode()
164         case strings.HasPrefix(p.Type, "[]"):
165                 return p.SliceTmpVarCode()
166         default:
167                 return ""
168         }
169 }
170
171 // TmpVarReadbackCode returns source code for reading back the temp variable into the original variable.
172 func (p *Param) TmpVarReadbackCode() string {
173         switch {
174         case p.Type == "*bool":
175                 return fmt.Sprintf("*%s = %s != 0", p.Name, p.tmpVar())
176         default:
177                 return ""
178         }
179 }
180
181 // TmpVarHelperCode returns source code for helper's temp variable.
182 func (p *Param) TmpVarHelperCode() string {
183         if p.Type != "string" {
184                 return ""
185         }
186         return p.StringTmpVarCode()
187 }
188
189 // SyscallArgList returns source code fragments representing p parameter
190 // in syscall. Slices are translated into 2 syscall parameters: pointer to
191 // the first element and length.
192 func (p *Param) SyscallArgList() []string {
193         t := p.HelperType()
194         var s string
195         switch {
196         case t == "*bool":
197                 s = fmt.Sprintf("unsafe.Pointer(&%s)", p.tmpVar())
198         case t[0] == '*':
199                 s = fmt.Sprintf("unsafe.Pointer(%s)", p.Name)
200         case t == "bool":
201                 s = p.tmpVar()
202         case strings.HasPrefix(t, "[]"):
203                 return []string{
204                         fmt.Sprintf("uintptr(unsafe.Pointer(%s))", p.tmpVar()),
205                         fmt.Sprintf("uintptr(len(%s))", p.Name),
206                 }
207         default:
208                 s = p.Name
209         }
210         return []string{fmt.Sprintf("uintptr(%s)", s)}
211 }
212
213 // IsError determines if p parameter is used to return error.
214 func (p *Param) IsError() bool {
215         return p.Name == "err" && p.Type == "error"
216 }
217
218 // HelperType returns type of parameter p used in helper function.
219 func (p *Param) HelperType() string {
220         if p.Type == "string" {
221                 return p.fn.StrconvType()
222         }
223         return p.Type
224 }
225
226 // join concatenates parameters ps into a string with sep separator.
227 // Each parameter is converted into string by applying fn to it
228 // before conversion.
229 func join(ps []*Param, fn func(*Param) string, sep string) string {
230         if len(ps) == 0 {
231                 return ""
232         }
233         a := make([]string, 0)
234         for _, p := range ps {
235                 a = append(a, fn(p))
236         }
237         return strings.Join(a, sep)
238 }
239
240 // Rets describes function return parameters.
241 type Rets struct {
242         Name          string
243         Type          string
244         ReturnsError  bool
245         FailCond      string
246         fnMaybeAbsent bool
247 }
248
249 // ErrorVarName returns error variable name for r.
250 func (r *Rets) ErrorVarName() string {
251         if r.ReturnsError {
252                 return "err"
253         }
254         if r.Type == "error" {
255                 return r.Name
256         }
257         return ""
258 }
259
260 // ToParams converts r into slice of *Param.
261 func (r *Rets) ToParams() []*Param {
262         ps := make([]*Param, 0)
263         if len(r.Name) > 0 {
264                 ps = append(ps, &Param{Name: r.Name, Type: r.Type})
265         }
266         if r.ReturnsError {
267                 ps = append(ps, &Param{Name: "err", Type: "error"})
268         }
269         return ps
270 }
271
272 // List returns source code of syscall return parameters.
273 func (r *Rets) List() string {
274         s := join(r.ToParams(), func(p *Param) string { return p.Name + " " + p.Type }, ", ")
275         if len(s) > 0 {
276                 s = "(" + s + ")"
277         } else if r.fnMaybeAbsent {
278                 s = "(err error)"
279         }
280         return s
281 }
282
283 // PrintList returns source code of trace printing part correspondent
284 // to syscall return values.
285 func (r *Rets) PrintList() string {
286         return join(r.ToParams(), func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `)
287 }
288
289 // SetReturnValuesCode returns source code that accepts syscall return values.
290 func (r *Rets) SetReturnValuesCode() string {
291         if r.Name == "" && !r.ReturnsError {
292                 return ""
293         }
294         retvar := "r0"
295         if r.Name == "" {
296                 retvar = "r1"
297         }
298         errvar := "_"
299         if r.ReturnsError {
300                 errvar = "e1"
301         }
302         return fmt.Sprintf("%s, _, %s := ", retvar, errvar)
303 }
304
305 func (r *Rets) useLongHandleErrorCode(retvar string) string {
306         const code = `if %s {
307                 err = errnoErr(e1)
308         }`
309         cond := retvar + " == 0"
310         if r.FailCond != "" {
311                 cond = strings.Replace(r.FailCond, "failretval", retvar, 1)
312         }
313         return fmt.Sprintf(code, cond)
314 }
315
316 // SetErrorCode returns source code that sets return parameters.
317 func (r *Rets) SetErrorCode() string {
318         const code = `if r0 != 0 {
319                 %s = %sErrno(r0)
320         }`
321         if r.Name == "" && !r.ReturnsError {
322                 return ""
323         }
324         if r.Name == "" {
325                 return r.useLongHandleErrorCode("r1")
326         }
327         if r.Type == "error" {
328                 return fmt.Sprintf(code, r.Name, syscalldot())
329         }
330         s := ""
331         switch {
332         case r.Type[0] == '*':
333                 s = fmt.Sprintf("%s = (%s)(unsafe.Pointer(r0))", r.Name, r.Type)
334         case r.Type == "bool":
335                 s = fmt.Sprintf("%s = r0 != 0", r.Name)
336         default:
337                 s = fmt.Sprintf("%s = %s(r0)", r.Name, r.Type)
338         }
339         if !r.ReturnsError {
340                 return s
341         }
342         return s + "\n\t" + r.useLongHandleErrorCode(r.Name)
343 }
344
345 // Fn describes syscall function.
346 type Fn struct {
347         Name        string
348         Params      []*Param
349         Rets        *Rets
350         PrintTrace  bool
351         dllname     string
352         dllfuncname string
353         src         string
354         // TODO: get rid of this field and just use parameter index instead
355         curTmpVarIdx int // insure tmp variables have uniq names
356 }
357
358 // extractParams parses s to extract function parameters.
359 func extractParams(s string, f *Fn) ([]*Param, error) {
360         s = trim(s)
361         if s == "" {
362                 return nil, nil
363         }
364         a := strings.Split(s, ",")
365         ps := make([]*Param, len(a))
366         for i := range ps {
367                 s2 := trim(a[i])
368                 b := strings.Split(s2, " ")
369                 if len(b) != 2 {
370                         b = strings.Split(s2, "\t")
371                         if len(b) != 2 {
372                                 return nil, errors.New("Could not extract function parameter from \"" + s2 + "\"")
373                         }
374                 }
375                 ps[i] = &Param{
376                         Name:      trim(b[0]),
377                         Type:      trim(b[1]),
378                         fn:        f,
379                         tmpVarIdx: -1,
380                 }
381         }
382         return ps, nil
383 }
384
385 // extractSection extracts text out of string s starting after start
386 // and ending just before end. found return value will indicate success,
387 // and prefix, body and suffix will contain correspondent parts of string s.
388 func extractSection(s string, start, end rune) (prefix, body, suffix string, found bool) {
389         s = trim(s)
390         if strings.HasPrefix(s, string(start)) {
391                 // no prefix
392                 body = s[1:]
393         } else {
394                 a := strings.SplitN(s, string(start), 2)
395                 if len(a) != 2 {
396                         return "", "", s, false
397                 }
398                 prefix = a[0]
399                 body = a[1]
400         }
401         a := strings.SplitN(body, string(end), 2)
402         if len(a) != 2 {
403                 return "", "", "", false
404         }
405         return prefix, a[0], a[1], true
406 }
407
408 // newFn parses string s and return created function Fn.
409 func newFn(s string) (*Fn, error) {
410         s = trim(s)
411         f := &Fn{
412                 Rets:       &Rets{},
413                 src:        s,
414                 PrintTrace: *printTraceFlag,
415         }
416         // function name and args
417         prefix, body, s, found := extractSection(s, '(', ')')
418         if !found || prefix == "" {
419                 return nil, errors.New("Could not extract function name and parameters from \"" + f.src + "\"")
420         }
421         f.Name = prefix
422         var err error
423         f.Params, err = extractParams(body, f)
424         if err != nil {
425                 return nil, err
426         }
427         // return values
428         _, body, s, found = extractSection(s, '(', ')')
429         if found {
430                 r, err := extractParams(body, f)
431                 if err != nil {
432                         return nil, err
433                 }
434                 switch len(r) {
435                 case 0:
436                 case 1:
437                         if r[0].IsError() {
438                                 f.Rets.ReturnsError = true
439                         } else {
440                                 f.Rets.Name = r[0].Name
441                                 f.Rets.Type = r[0].Type
442                         }
443                 case 2:
444                         if !r[1].IsError() {
445                                 return nil, errors.New("Only last windows error is allowed as second return value in \"" + f.src + "\"")
446                         }
447                         f.Rets.ReturnsError = true
448                         f.Rets.Name = r[0].Name
449                         f.Rets.Type = r[0].Type
450                 default:
451                         return nil, errors.New("Too many return values in \"" + f.src + "\"")
452                 }
453         }
454         // fail condition
455         _, body, s, found = extractSection(s, '[', ']')
456         if found {
457                 f.Rets.FailCond = body
458         }
459         // dll and dll function names
460         s = trim(s)
461         if s == "" {
462                 return f, nil
463         }
464         if !strings.HasPrefix(s, "=") {
465                 return nil, errors.New("Could not extract dll name from \"" + f.src + "\"")
466         }
467         s = trim(s[1:])
468         a := strings.Split(s, ".")
469         switch len(a) {
470         case 1:
471                 f.dllfuncname = a[0]
472         case 2:
473                 f.dllname = a[0]
474                 f.dllfuncname = a[1]
475         default:
476                 return nil, errors.New("Could not extract dll name from \"" + f.src + "\"")
477         }
478         if n := f.dllfuncname; strings.HasSuffix(n, "?") {
479                 f.dllfuncname = n[:len(n)-1]
480                 f.Rets.fnMaybeAbsent = true
481         }
482         return f, nil
483 }
484
485 // DLLName returns DLL name for function f.
486 func (f *Fn) DLLName() string {
487         if f.dllname == "" {
488                 return "kernel32"
489         }
490         return f.dllname
491 }
492
493 // DLLName returns DLL function name for function f.
494 func (f *Fn) DLLFuncName() string {
495         if f.dllfuncname == "" {
496                 return f.Name
497         }
498         return f.dllfuncname
499 }
500
501 // ParamList returns source code for function f parameters.
502 func (f *Fn) ParamList() string {
503         return join(f.Params, func(p *Param) string { return p.Name + " " + p.Type }, ", ")
504 }
505
506 // HelperParamList returns source code for helper function f parameters.
507 func (f *Fn) HelperParamList() string {
508         return join(f.Params, func(p *Param) string { return p.Name + " " + p.HelperType() }, ", ")
509 }
510
511 // ParamPrintList returns source code of trace printing part correspondent
512 // to syscall input parameters.
513 func (f *Fn) ParamPrintList() string {
514         return join(f.Params, func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `)
515 }
516
517 // ParamCount return number of syscall parameters for function f.
518 func (f *Fn) ParamCount() int {
519         n := 0
520         for _, p := range f.Params {
521                 n += len(p.SyscallArgList())
522         }
523         return n
524 }
525
526 // SyscallParamCount determines which version of Syscall/Syscall6/Syscall9/...
527 // to use. It returns parameter count for correspondent SyscallX function.
528 func (f *Fn) SyscallParamCount() int {
529         n := f.ParamCount()
530         switch {
531         case n <= 3:
532                 return 3
533         case n <= 6:
534                 return 6
535         case n <= 9:
536                 return 9
537         case n <= 12:
538                 return 12
539         case n <= 15:
540                 return 15
541         default:
542                 panic("too many arguments to system call")
543         }
544 }
545
546 // Syscall determines which SyscallX function to use for function f.
547 func (f *Fn) Syscall() string {
548         c := f.SyscallParamCount()
549         if c == 3 {
550                 return syscalldot() + "Syscall"
551         }
552         return syscalldot() + "Syscall" + strconv.Itoa(c)
553 }
554
555 // SyscallParamList returns source code for SyscallX parameters for function f.
556 func (f *Fn) SyscallParamList() string {
557         a := make([]string, 0)
558         for _, p := range f.Params {
559                 a = append(a, p.SyscallArgList()...)
560         }
561         for len(a) < f.SyscallParamCount() {
562                 a = append(a, "0")
563         }
564         return strings.Join(a, ", ")
565 }
566
567 // HelperCallParamList returns source code of call into function f helper.
568 func (f *Fn) HelperCallParamList() string {
569         a := make([]string, 0, len(f.Params))
570         for _, p := range f.Params {
571                 s := p.Name
572                 if p.Type == "string" {
573                         s = p.tmpVar()
574                 }
575                 a = append(a, s)
576         }
577         return strings.Join(a, ", ")
578 }
579
580 // MaybeAbsent returns source code for handling functions that are possibly unavailable.
581 func (p *Fn) MaybeAbsent() string {
582         if !p.Rets.fnMaybeAbsent {
583                 return ""
584         }
585         const code = `%[1]s = proc%[2]s.Find()
586         if %[1]s != nil {
587                 return
588         }`
589         errorVar := p.Rets.ErrorVarName()
590         if errorVar == "" {
591                 errorVar = "err"
592         }
593         return fmt.Sprintf(code, errorVar, p.DLLFuncName())
594 }
595
596 // IsUTF16 is true, if f is W (utf16) function. It is false
597 // for all A (ascii) functions.
598 func (f *Fn) IsUTF16() bool {
599         s := f.DLLFuncName()
600         return s[len(s)-1] == 'W'
601 }
602
603 // StrconvFunc returns name of Go string to OS string function for f.
604 func (f *Fn) StrconvFunc() string {
605         if f.IsUTF16() {
606                 return syscalldot() + "UTF16PtrFromString"
607         }
608         return syscalldot() + "BytePtrFromString"
609 }
610
611 // StrconvType returns Go type name used for OS string for f.
612 func (f *Fn) StrconvType() string {
613         if f.IsUTF16() {
614                 return "*uint16"
615         }
616         return "*byte"
617 }
618
619 // HasStringParam is true, if f has at least one string parameter.
620 // Otherwise it is false.
621 func (f *Fn) HasStringParam() bool {
622         for _, p := range f.Params {
623                 if p.Type == "string" {
624                         return true
625                 }
626         }
627         return false
628 }
629
630 // HelperName returns name of function f helper.
631 func (f *Fn) HelperName() string {
632         if !f.HasStringParam() {
633                 return f.Name
634         }
635         return "_" + f.Name
636 }
637
638 // Source files and functions.
639 type Source struct {
640         Funcs           []*Fn
641         Files           []string
642         StdLibImports   []string
643         ExternalImports []string
644 }
645
646 func (src *Source) Import(pkg string) {
647         src.StdLibImports = append(src.StdLibImports, pkg)
648         sort.Strings(src.StdLibImports)
649 }
650
651 func (src *Source) ExternalImport(pkg string) {
652         src.ExternalImports = append(src.ExternalImports, pkg)
653         sort.Strings(src.ExternalImports)
654 }
655
656 // ParseFiles parses files listed in fs and extracts all syscall
657 // functions listed in sys comments. It returns source files
658 // and functions collection *Source if successful.
659 func ParseFiles(fs []string) (*Source, error) {
660         src := &Source{
661                 Funcs: make([]*Fn, 0),
662                 Files: make([]string, 0),
663                 StdLibImports: []string{
664                         "unsafe",
665                 },
666                 ExternalImports: make([]string, 0),
667         }
668         for _, file := range fs {
669                 if err := src.ParseFile(file); err != nil {
670                         return nil, err
671                 }
672         }
673         return src, nil
674 }
675
676 // DLLs return dll names for a source set src.
677 func (src *Source) DLLs() []string {
678         uniq := make(map[string]bool)
679         r := make([]string, 0)
680         for _, f := range src.Funcs {
681                 name := f.DLLName()
682                 if _, found := uniq[name]; !found {
683                         uniq[name] = true
684                         r = append(r, name)
685                 }
686         }
687         sort.Strings(r)
688         return r
689 }
690
691 // ParseFile adds additional file path to a source set src.
692 func (src *Source) ParseFile(path string) error {
693         file, err := os.Open(path)
694         if err != nil {
695                 return err
696         }
697         defer file.Close()
698
699         s := bufio.NewScanner(file)
700         for s.Scan() {
701                 t := trim(s.Text())
702                 if len(t) < 7 {
703                         continue
704                 }
705                 if !strings.HasPrefix(t, "//sys") {
706                         continue
707                 }
708                 t = t[5:]
709                 if !(t[0] == ' ' || t[0] == '\t') {
710                         continue
711                 }
712                 f, err := newFn(t[1:])
713                 if err != nil {
714                         return err
715                 }
716                 src.Funcs = append(src.Funcs, f)
717         }
718         if err := s.Err(); err != nil {
719                 return err
720         }
721         src.Files = append(src.Files, path)
722         sort.Slice(src.Funcs, func(i, j int) bool {
723                 fi, fj := src.Funcs[i], src.Funcs[j]
724                 if fi.DLLName() == fj.DLLName() {
725                         return fi.DLLFuncName() < fj.DLLFuncName()
726                 }
727                 return fi.DLLName() < fj.DLLName()
728         })
729
730         // get package name
731         fset := token.NewFileSet()
732         _, err = file.Seek(0, 0)
733         if err != nil {
734                 return err
735         }
736         pkg, err := parser.ParseFile(fset, "", file, parser.PackageClauseOnly)
737         if err != nil {
738                 return err
739         }
740         packageName = pkg.Name.Name
741
742         return nil
743 }
744
745 // IsStdRepo reports whether src is part of standard library.
746 func (src *Source) IsStdRepo() (bool, error) {
747         if len(src.Files) == 0 {
748                 return false, errors.New("no input files provided")
749         }
750         abspath, err := filepath.Abs(src.Files[0])
751         if err != nil {
752                 return false, err
753         }
754         goroot := runtime.GOROOT()
755         if runtime.GOOS == "windows" {
756                 abspath = strings.ToLower(abspath)
757                 goroot = strings.ToLower(goroot)
758         }
759         sep := string(os.PathSeparator)
760         if !strings.HasSuffix(goroot, sep) {
761                 goroot += sep
762         }
763         return strings.HasPrefix(abspath, goroot), nil
764 }
765
766 // Generate output source file from a source set src.
767 func (src *Source) Generate(w io.Writer) error {
768         const (
769                 pkgStd         = iota // any package in std library
770                 pkgXSysWindows        // x/sys/windows package
771                 pkgOther
772         )
773         isStdRepo, err := src.IsStdRepo()
774         if err != nil {
775                 return err
776         }
777         var pkgtype int
778         switch {
779         case isStdRepo:
780                 pkgtype = pkgStd
781         case packageName == "windows":
782                 // TODO: this needs better logic than just using package name
783                 pkgtype = pkgXSysWindows
784         default:
785                 pkgtype = pkgOther
786         }
787         if *systemDLL {
788                 switch pkgtype {
789                 case pkgStd:
790                         src.Import("internal/syscall/windows/sysdll")
791                 case pkgXSysWindows:
792                 default:
793                         src.ExternalImport("golang.org/x/sys/windows")
794                 }
795         }
796         if packageName != "syscall" {
797                 src.Import("syscall")
798         }
799         funcMap := template.FuncMap{
800                 "packagename": packagename,
801                 "syscalldot":  syscalldot,
802                 "newlazydll": func(dll string) string {
803                         arg := "\"" + dll + ".dll\""
804                         if !*systemDLL {
805                                 return syscalldot() + "NewLazyDLL(" + arg + ")"
806                         }
807                         switch pkgtype {
808                         case pkgStd:
809                                 return syscalldot() + "NewLazyDLL(sysdll.Add(" + arg + "))"
810                         case pkgXSysWindows:
811                                 return "NewLazySystemDLL(" + arg + ")"
812                         default:
813                                 return "windows.NewLazySystemDLL(" + arg + ")"
814                         }
815                 },
816         }
817         t := template.Must(template.New("main").Funcs(funcMap).Parse(srcTemplate))
818         err = t.Execute(w, src)
819         if err != nil {
820                 return errors.New("Failed to execute template: " + err.Error())
821         }
822         return nil
823 }
824
825 func usage() {
826         fmt.Fprintf(os.Stderr, "usage: mkwinsyscall [flags] [path ...]\n")
827         flag.PrintDefaults()
828         os.Exit(1)
829 }
830
831 func main() {
832         flag.Usage = usage
833         flag.Parse()
834         if len(flag.Args()) <= 0 {
835                 fmt.Fprintf(os.Stderr, "no files to parse provided\n")
836                 usage()
837         }
838
839         src, err := ParseFiles(flag.Args())
840         if err != nil {
841                 log.Fatal(err)
842         }
843
844         var buf bytes.Buffer
845         if err := src.Generate(&buf); err != nil {
846                 log.Fatal(err)
847         }
848
849         data, err := format.Source(buf.Bytes())
850         if err != nil {
851                 log.Fatal(err)
852         }
853         if *filename == "" {
854                 _, err = os.Stdout.Write(data)
855         } else {
856                 err = ioutil.WriteFile(*filename, data, 0644)
857         }
858         if err != nil {
859                 log.Fatal(err)
860         }
861 }
862
863 // TODO: use println instead to print in the following template
864 const srcTemplate = `
865
866 {{define "main"}}// Code generated by 'go generate'; DO NOT EDIT.
867
868 package {{packagename}}
869
870 import (
871 {{range .StdLibImports}}"{{.}}"
872 {{end}}
873
874 {{range .ExternalImports}}"{{.}}"
875 {{end}}
876 )
877
878 var _ unsafe.Pointer
879
880 // Do the interface allocations only once for common
881 // Errno values.
882 const (
883         errnoERROR_IO_PENDING = 997
884 )
885
886 var (
887         errERROR_IO_PENDING error = {{syscalldot}}Errno(errnoERROR_IO_PENDING)
888         errERROR_EINVAL error     = {{syscalldot}}EINVAL
889 )
890
891 // errnoErr returns common boxed Errno values, to prevent
892 // allocations at runtime.
893 func errnoErr(e {{syscalldot}}Errno) error {
894         switch e {
895         case 0:
896                 return errERROR_EINVAL
897         case errnoERROR_IO_PENDING:
898                 return errERROR_IO_PENDING
899         }
900         // TODO: add more here, after collecting data on the common
901         // error values see on Windows. (perhaps when running
902         // all.bat?)
903         return e
904 }
905
906 var (
907 {{template "dlls" .}}
908 {{template "funcnames" .}})
909 {{range .Funcs}}{{if .HasStringParam}}{{template "helperbody" .}}{{end}}{{template "funcbody" .}}{{end}}
910 {{end}}
911
912 {{/* help functions */}}
913
914 {{define "dlls"}}{{range .DLLs}}        mod{{.}} = {{newlazydll .}}
915 {{end}}{{end}}
916
917 {{define "funcnames"}}{{range .Funcs}}  proc{{.DLLFuncName}} = mod{{.DLLName}}.NewProc("{{.DLLFuncName}}")
918 {{end}}{{end}}
919
920 {{define "helperbody"}}
921 func {{.Name}}({{.ParamList}}) {{template "results" .}}{
922 {{template "helpertmpvars" .}}  return {{.HelperName}}({{.HelperCallParamList}})
923 }
924 {{end}}
925
926 {{define "funcbody"}}
927 func {{.HelperName}}({{.HelperParamList}}) {{template "results" .}}{
928 {{template "maybeabsent" .}}    {{template "tmpvars" .}}        {{template "syscall" .}}        {{template "tmpvarsreadback" .}}
929 {{template "seterror" .}}{{template "printtrace" .}}    return
930 }
931 {{end}}
932
933 {{define "helpertmpvars"}}{{range .Params}}{{if .TmpVarHelperCode}}     {{.TmpVarHelperCode}}
934 {{end}}{{end}}{{end}}
935
936 {{define "maybeabsent"}}{{if .MaybeAbsent}}{{.MaybeAbsent}}
937 {{end}}{{end}}
938
939 {{define "tmpvars"}}{{range .Params}}{{if .TmpVarCode}} {{.TmpVarCode}}
940 {{end}}{{end}}{{end}}
941
942 {{define "results"}}{{if .Rets.List}}{{.Rets.List}} {{end}}{{end}}
943
944 {{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}}
945
946 {{define "tmpvarsreadback"}}{{range .Params}}{{if .TmpVarReadbackCode}}
947 {{.TmpVarReadbackCode}}{{end}}{{end}}{{end}}
948
949 {{define "seterror"}}{{if .Rets.SetErrorCode}}  {{.Rets.SetErrorCode}}
950 {{end}}{{end}}
951
952 {{define "printtrace"}}{{if .PrintTrace}}       print("SYSCALL: {{.Name}}(", {{.ParamPrintList}}") (", {{.Rets.PrintList}}")\n")
953 {{end}}{{end}}
954
955 `