// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// authtest is a diagnostic tool for implementations of the GOAUTH protocol
// described in https://golang.org/issue/26232.
//
// It accepts a single URL as an argument, and executes the GOAUTH protocol to
// fetch and display the headers for that URL.
//
// CAUTION: authtest logs the GOAUTH responses, which may include user
// credentials, to stderr. Do not post its output unless you are certain that
// all of the credentials involved are fake!
package main

import (
	"bufio"
	"bytes"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/textproto"
	"net/url"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
)

// v is the -v command-line flag.
//
// NOTE(review): nothing visible in this file reads v; the GOAUTH responses
// are written to stderr unconditionally by getCreds. Confirm whether the
// flag is meant to gate that output before changing either side.
var v = flag.Bool("v", false, "if true, log GOAUTH responses to stderr")

// main issues an unauthenticated request for the URL given on the command
// line. If that does not yield 200 OK, it retries once, passing the first
// response to the GOAUTH commands, and exits with status 1 if the retry is
// not 200 OK either.
func main() {
	log.SetFlags(log.LstdFlags | log.Lshortfile)
	flag.Parse()
	args := flag.Args()
	if len(args) != 1 {
		log.Fatalf("usage: [GOAUTH=CMD...] %s URL", filepath.Base(os.Args[0]))
	}

	resp := try(args[0], nil)
	if resp.StatusCode == http.StatusOK {
		return
	}

	resp = try(args[0], resp)
	if resp.StatusCode != http.StatusOK {
		os.Exit(1)
	}
}

// try performs one round of the GOAUTH protocol and returns the response.
//
// It runs each ';'-separated command in the GOAUTH environment variable
// (via getCreds) until one of them yields a credential that applies to the
// request, then sends the request with http.DefaultClient and writes the
// response status line and headers to stderr.
//
// If prev is nil, a new HEAD request for url is sent. Otherwise the request
// that produced prev is sent again and prev is passed on to the GOAUTH
// commands.
//
// Any failure, including a status that is neither 200 nor in [400, 500],
// terminates the process via log.Fatal. The returned response's Body is
// already closed; only its status and headers are meaningful.
func try(url string, prev *http.Response) *http.Response {
	req, err := newRequest(url, prev)
	if err != nil {
		log.Fatal(err)
	}

goauth:
	for _, argList := range strings.Split(os.Getenv("GOAUTH"), ";") {
		// TODO(golang.org/issue/26849): If we escape quoted strings in GOFLAGS, use
		// the same quoting here.
		args := strings.Split(argList, " ")
		if len(args) == 0 || args[0] == "" {
			log.Fatalf("invalid or empty command in GOAUTH")
		}

		creds, err := getCreds(args, prev)
		if err != nil {
			log.Fatal(err)
		}
		for _, c := range creds {
			if c.Apply(req) {
				fmt.Fprintf(os.Stderr, "# request to %s\n", req.URL)
				fmt.Fprintf(os.Stderr, "%s %s %s\n", req.Method, req.URL, req.Proto)
				req.Header.Write(os.Stderr)
				fmt.Fprintln(os.Stderr)
				// The first applicable credential wins; skip the remaining
				// credentials and the remaining GOAUTH commands.
				break goauth
			}
		}
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	// Accept 200 and anything in [400, 500]; every other status is fatal.
	if (resp.StatusCode != http.StatusOK && resp.StatusCode < 400) || resp.StatusCode > 500 {
		log.Fatalf("unexpected status: %v", resp.Status)
	}

	fmt.Fprintf(os.Stderr, "# response from %s\n", resp.Request.URL)
	formatHead(os.Stderr, resp)
	return resp
}

// newRequest builds the request that try sends.
//
// With a nil prev it returns a fresh HEAD request for rawURL. The URL is
// taken from the argument rather than from os.Args[1], which is a flag
// (not the URL) whenever flags precede the URL on the command line.
// With a non-nil prev it returns a shallow copy of prev.Request, so the
// retry targets whatever request produced prev; rawURL is ignored.
func newRequest(rawURL string, prev *http.Response) (*http.Request, error) {
	if prev != nil {
		req := new(http.Request)
		*req = *prev.Request
		return req, nil
	}
	return http.NewRequest("HEAD", rawURL, nil)
}

// formatHead writes the status line and headers of resp to out, followed by
// a blank line. A failure to write the headers terminates the process.
func formatHead(out io.Writer, resp *http.Response) {
	fmt.Fprintf(out, "%s %s\n", resp.Proto, resp.Status)
	if err := resp.Header.Write(out); err != nil {
		log.Fatal(err)
	}
	fmt.Fprintln(out)
}

// A Cred is one credential parsed from a GOAUTH command's output: a set of
// headers together with the URL prefixes they apply to.
type Cred struct {
	URLPrefixes []*url.URL
	Header      http.Header
}

// Apply sets c's headers on req if req.URL falls under one of c's URL
// prefixes, replacing any existing values for the same header keys.
// It reports whether the headers were applied.
func (c Cred) Apply(req *http.Request) bool {
	if req.URL == nil {
		return false
	}
	if !c.matches(req.URL) {
		return false
	}

	// Header.Add panics on a nil map; a hand-built Request may have none.
	if req.Header == nil {
		req.Header = make(http.Header)
	}
	for k, vs := range c.Header {
		req.Header.Del(k)
		for _, val := range vs {
			req.Header.Add(k, val)
		}
	}
	return true
}

// matches reports whether u falls under one of c's URL prefixes.
//
// A prefix matches when scheme and host are equal and u's path is either
// identical to the prefix path or extends it at a '/' boundary, so the
// prefix "/foo" covers "/foo" and "/foo/bar" but not "/foobar". Comparing
// the scheme keeps credentials issued for an https prefix from being
// attached to a plaintext http request for the same host.
func (c Cred) matches(u *url.URL) bool {
	for _, prefix := range c.URLPrefixes {
		if prefix.Scheme != u.Scheme || prefix.Host != u.Host {
			continue
		}
		if u.Path == prefix.Path {
			return true
		}
		if !strings.HasPrefix(u.Path, prefix.Path) {
			continue
		}
		// u.Path is strictly longer than prefix.Path here, so the index
		// below is in range.
		if strings.HasSuffix(prefix.Path, "/") || u.Path[len(prefix.Path)] == '/' {
			return true
		}
	}
	return false
}

// String formats c the way getCreds parses it: one URL prefix per line, a
// blank line, the headers, and a final blank line.
func (c Cred) String() string {
	var buf strings.Builder
	for _, u := range c.URLPrefixes {
		fmt.Fprintln(&buf, u)
	}
	buf.WriteString("\n")
	c.Header.Write(&buf)
	buf.WriteString("\n")
	return buf.String()
}

// getCreds runs the GOAUTH command args and parses the credentials it
// prints on stdout.
//
// If resp is non-nil, the URL of the request that produced it (with the
// query string removed) is appended to the command's arguments, and the
// response's status line and headers are supplied on the command's stdin.
// The command line and the command's raw output are echoed to stderr.
func getCreds(args []string, resp *http.Response) ([]Cred, error) {
	cmd := exec.Command(args[0], args[1:]...)
	cmd.Stderr = os.Stderr

	if resp != nil {
		u := *resp.Request.URL
		u.RawQuery = ""
		cmd.Args = append(cmd.Args, u.String())
	}

	var head strings.Builder
	if resp != nil {
		formatHead(&head, resp)
	}
	cmd.Stdin = strings.NewReader(head.String())

	fmt.Fprintf(os.Stderr, "# %s\n", strings.Join(cmd.Args, " "))
	out, err := cmd.Output()
	if err != nil {
		return nil, fmt.Errorf("%s: %v", strings.Join(cmd.Args, " "), err)
	}
	os.Stderr.Write(out)
	os.Stderr.WriteString("\n")

	return parseCreds(out)
}

// parseCreds parses the output of a GOAUTH command.
//
// The output is a sequence of zero or more groups, each consisting of one
// or more URL lines, a blank line, a MIME-style header block, and the blank
// line that ends the header block. Every URL must use the https scheme and
// carry no query string or fragment. Groups whose header block is empty are
// dropped.
//
// The line numbers in error messages count only the URL lines and the blank
// line following them; lines consumed as headers are not counted.
func parseCreds(out []byte) ([]Cred, error) {
	var creds []Cred
	r := textproto.NewReader(bufio.NewReader(bytes.NewReader(out)))
	line := 0

readLoop:
	for {
		var prefixes []*url.URL
		for {
			prefix, err := r.ReadLine()
			if err == io.EOF {
				if len(prefixes) > 0 {
					// URL lines with no header block after them.
					return nil, fmt.Errorf("line %d: %v", line, io.ErrUnexpectedEOF)
				}
				break readLoop
			}
			if err != nil {
				return nil, fmt.Errorf("line %d: %v", line+1, err)
			}
			line++

			if prefix == "" {
				if len(prefixes) == 0 {
					return nil, fmt.Errorf("line %d: unexpected newline", line)
				}
				break
			}
			u, err := url.Parse(prefix)
			if err != nil {
				return nil, fmt.Errorf("line %d: malformed URL: %v", line, err)
			}
			if u.Scheme != "https" {
				return nil, fmt.Errorf("line %d: non-HTTPS URL %q", line, prefix)
			}
			if len(u.RawQuery) > 0 {
				return nil, fmt.Errorf("line %d: unexpected query string in URL %q", line, prefix)
			}
			if len(u.Fragment) > 0 {
				return nil, fmt.Errorf("line %d: unexpected fragment in URL %q", line, prefix)
			}
			prefixes = append(prefixes, u)
		}

		header, err := r.ReadMIMEHeader()
		if err != nil {
			return nil, fmt.Errorf("headers at line %d: %v", line, err)
		}
		if len(header) > 0 {
			creds = append(creds, Cred{
				URLPrefixes: prefixes,
				Header:      http.Header(header),
			})
		}
	}

	return creds, nil
}