Sergey Matveev's repositories - tofuproxy.git/blob - trip.go
Set BasicAuth before request
[tofuproxy.git] / trip.go
1 /*
2 tofuproxy -- HTTP proxy with TLS certificates management
3 Copyright (C) 2021 Sergey Matveev <stargrave@stargrave.org>
4
5 This program is free software: you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation, version 3 of the License.
8
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 GNU General Public License for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 package tofuproxy
19
20 import (
21         "fmt"
22         "io"
23         "log"
24         "net"
25         "net/http"
26         "strings"
27         "time"
28
29         "github.com/dustin/go-humanize"
30         "go.stargrave.org/tofuproxy/fifos"
31         "go.stargrave.org/tofuproxy/rounds"
32 )
33
34 var (
35         transport = http.Transport{
36                 DialContext: (&net.Dialer{
37                         Timeout:   time.Minute,
38                         KeepAlive: time.Minute,
39                 }).DialContext,
40                 MaxIdleConns:        http.DefaultTransport.(*http.Transport).MaxIdleConns,
41                 IdleConnTimeout:     http.DefaultTransport.(*http.Transport).IdleConnTimeout * 2,
42                 TLSHandshakeTimeout: time.Minute,
43                 DialTLSContext:      dialTLS,
44                 ForceAttemptHTTP2:   true,
45         }
46         proxyHeaders = map[string]struct{}{
47                 "Location":       {},
48                 "Content-Type":   {},
49                 "Content-Length": {},
50         }
51 )
52
53 type Round func(
54         host string,
55         resp *http.Response,
56         w http.ResponseWriter,
57         req *http.Request,
58 ) (bool, error)
59
60 func roundTrip(w http.ResponseWriter, req *http.Request) {
61         fifos.SinkReq <- fmt.Sprintf("%s %s", req.Method, req.URL.String())
62         host := strings.TrimSuffix(req.URL.Host, ":443")
63         for _, round := range []Round{
64                 rounds.RoundNoHead,
65                 rounds.RoundDenySpy,
66                 rounds.RoundRedditOld,
67                 rounds.RoundHabrImage,
68         } {
69                 if cont, _ := round(host, nil, w, req); !cont {
70                         return
71                 }
72         }
73
74         reqFlags := []string{}
75         unauthorized := false
76
77         authCacheM.Lock()
78         if creds, ok := authCache[req.URL.Host]; ok {
79                 req.SetBasicAuth(creds[0], creds[1])
80                 unauthorized = true
81         }
82         authCacheM.Unlock()
83
84 Retry:
85         resp, err := transport.RoundTrip(req)
86         if err != nil {
87                 fifos.SinkErr <- fmt.Sprintf("%s\t%s", req.URL.Host, err.Error())
88                 http.Error(w, err.Error(), http.StatusBadGateway)
89                 return
90         }
91
92         if resp.StatusCode == http.StatusUnauthorized {
93                 resp.Body.Close()
94                 authCacheM.Lock()
95                 if unauthorized {
96                         delete(authCache, req.URL.Host)
97                 } else {
98                         unauthorized = true
99                 }
100                 fifos.SinkOther <- fmt.Sprintf("%s\tauthorization required", req.URL.Host)
101                 user, pass, err := authDialog(host, resp.Header.Get("WWW-Authenticate"))
102                 if err != nil {
103                         authCacheM.Unlock()
104                         fifos.SinkErr <- fmt.Sprintf("%s\t%s", req.URL.Host, err.Error())
105                         http.Error(w, err.Error(), http.StatusInternalServerError)
106                         return
107                 }
108                 authCache[req.URL.Host] = [2]string{user, pass}
109                 authCacheM.Unlock()
110                 req.SetBasicAuth(user, pass)
111                 goto Retry
112         }
113         if unauthorized {
114                 reqFlags = append(reqFlags, "auth")
115         }
116         if resp.TLS != nil && resp.TLS.NegotiatedProtocol != "" {
117                 reqFlags = append(reqFlags, resp.TLS.NegotiatedProtocol)
118         }
119
120         for k, vs := range resp.Header {
121                 if _, ok := proxyHeaders[k]; ok {
122                         continue
123                 }
124                 for _, v := range vs {
125                         w.Header().Add(k, v)
126                 }
127         }
128
129         for _, round := range []Round{
130                 rounds.RoundDenyFonts,
131                 rounds.RoundTranscodeWebP,
132                 rounds.RoundTranscodeJXL,
133                 rounds.RoundTranscodeAVIF,
134                 rounds.RoundRedirectHTML,
135         } {
136                 cont, err := round(host, resp, w, req)
137                 if err != nil {
138                         http.Error(w, err.Error(), http.StatusBadGateway)
139                         return
140                 }
141                 if !cont {
142                         return
143                 }
144         }
145
146         for h := range proxyHeaders {
147                 if v := resp.Header.Get(h); v != "" {
148                         w.Header().Add(h, v)
149                 }
150         }
151         w.WriteHeader(resp.StatusCode)
152         n, err := io.Copy(w, resp.Body)
153         if err != nil {
154                 log.Printf("Error during %s: %+v\n", req.URL, err)
155         }
156         resp.Body.Close()
157         msg := fmt.Sprintf(
158                 "%s %s\t%s\t%s\t%s\t%s",
159                 req.Method,
160                 req.URL.String(),
161                 resp.Status,
162                 resp.Header.Get("Content-Type"),
163                 humanize.IBytes(uint64(n)),
164                 strings.Join(reqFlags, ","),
165         )
166         if resp.StatusCode == http.StatusOK {
167                 fifos.SinkOK <- msg
168         } else {
169                 fifos.SinkOther <- msg
170         }
171 }