]> Sergey Matveev's repositories - tofuproxy.git/blobdiff - trip.go
HTTP authorization
[tofuproxy.git] / trip.go
diff --git a/trip.go b/trip.go
index 3ad6eb178f1dca2c24d738c883ae7481d5b3a65e..fa344457f97721c532c28c6c2b2ca15599dddd1e 100644 (file)
--- a/trip.go
+++ b/trip.go
@@ -71,6 +71,9 @@ func roundTrip(w http.ResponseWriter, req *http.Request) {
                }
        }
 
+       reqFlags := []string{}
+       unauthorized := false
+Retry:
        resp, err := transport.RoundTrip(req)
        if err != nil {
                fifos.SinkErr <- fmt.Sprintf("%s\t%s", req.URL.Host, err.Error())
@@ -78,6 +81,39 @@ func roundTrip(w http.ResponseWriter, req *http.Request) {
                return
        }
 
+       if resp.StatusCode == http.StatusUnauthorized {
+               resp.Body.Close()
+               authCacheM.Lock()
+               if unauthorized {
+                       delete(authCache, req.URL.Host)
+               } else {
+                       unauthorized = true
+                       if creds, ok := authCache[req.URL.Host]; ok {
+                               authCacheM.Unlock()
+                               req.SetBasicAuth(creds[0], creds[1])
+                               goto Retry
+                       }
+               }
+               fifos.SinkOther <- fmt.Sprintf("%s\tauthorization required", req.URL.Host)
+               user, pass, err := authDialog(host, resp.Header.Get("WWW-Authenticate"))
+               if err != nil {
+                       authCacheM.Unlock()
+                       fifos.SinkErr <- fmt.Sprintf("%s\t%s", req.URL.Host, err.Error())
+                       http.Error(w, err.Error(), http.StatusInternalServerError)
+                       return
+               }
+               authCache[req.URL.Host] = [2]string{user, pass}
+               authCacheM.Unlock()
+               req.SetBasicAuth(user, pass)
+               goto Retry
+       }
+       if unauthorized {
+               reqFlags = append(reqFlags, "auth")
+       }
+       if resp.TLS != nil && resp.TLS.NegotiatedProtocol != "" {
+               reqFlags = append(reqFlags, resp.TLS.NegotiatedProtocol)
+       }
+
        for k, vs := range resp.Header {
                if _, ok := proxyHeaders[k]; ok {
                        continue
@@ -116,12 +152,13 @@ func roundTrip(w http.ResponseWriter, req *http.Request) {
        }
        resp.Body.Close()
        msg := fmt.Sprintf(
-               "%s %s\t%s\t%s\t%s",
+               "%s %s\t%s\t%s\t%s\t%s",
                req.Method,
                req.URL.String(),
                resp.Status,
                resp.Header.Get("Content-Type"),
                humanize.IBytes(uint64(n)),
+               strings.Join(reqFlags, ","),
        )
        if resp.StatusCode == http.StatusOK {
                fifos.SinkOK <- msg