Sergey Matveev's repositories - tofuproxy.git/blob - trip.go
-mod=vendor if vendor/ exists
[tofuproxy.git] / trip.go
1 // tofuproxy -- flexible HTTP/HTTPS proxy, TLS terminator, X.509 TOFU
2 //              manager, WARC/geminispace browser
3 // Copyright (C) 2021-2024 Sergey Matveev <stargrave@stargrave.org>
4 //
5 // This program is free software: you can redistribute it and/or modify
6 // it under the terms of the GNU General Public License as published by
7 // the Free Software Foundation, version 3 of the License.
8 //
9 // This program is distributed in the hope that it will be useful,
10 // but WITHOUT ANY WARRANTY; without even the implied warranty of
11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 // GNU General Public License for more details.
13 //
14 // You should have received a copy of the GNU General Public License
15 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
16
17 package tofuproxy
18
19 import (
20         "fmt"
21         "io"
22         "log"
23         "net"
24         "net/http"
25         "strings"
26         "time"
27
28         "github.com/dustin/go-humanize"
29         "go.stargrave.org/tofuproxy/caches"
30         "go.stargrave.org/tofuproxy/fifos"
31         "go.stargrave.org/tofuproxy/rounds"
32         ttls "go.stargrave.org/tofuproxy/tls"
33 )
34
35 var (
36         transport = http.Transport{
37                 DialContext: (&net.Dialer{
38                         Timeout:   time.Minute,
39                         KeepAlive: time.Minute,
40                 }).DialContext,
41                 MaxIdleConns:        http.DefaultTransport.(*http.Transport).MaxIdleConns,
42                 IdleConnTimeout:     http.DefaultTransport.(*http.Transport).IdleConnTimeout * 2,
43                 TLSHandshakeTimeout: time.Minute,
44                 DialTLSContext:      ttls.DialTLS,
45                 ForceAttemptHTTP2:   true,
46         }
47         proxyHeaders = map[string]struct{}{
48                 "Location":       {},
49                 "Content-Type":   {},
50                 "Content-Length": {},
51         }
52 )
53
// Round is one processing step of roundTrip.
//
// It receives the request host (req.URL.Host without a ":443" suffix), the
// upstream response, the client's ResponseWriter and the client request.
// Steps run before anything is sent upstream get a nil resp, and their error
// result is ignored. Returning cont == false makes roundTrip stop handling the
// request. A non-nil err from a step that is given an upstream response makes
// roundTrip answer with 502 Bad Gateway.
type Round func(
	host string, resp *http.Response,
	w http.ResponseWriter, req *http.Request,
) (cont bool, err error)
60
61 func roundTrip(w http.ResponseWriter, req *http.Request) {
62         defer req.Body.Close()
63         fifos.LogReq <- fmt.Sprintf("%s %s", req.Method, req.URL)
64         host := strings.TrimSuffix(req.URL.Host, ":443")
65         for _, round := range []Round{
66                 rounds.RoundGemini,
67                 rounds.RoundWARC,
68                 rounds.RoundDenySpy,
69                 rounds.RoundRedditOld,
70                 rounds.RoundHabrImage,
71         } {
72                 if cont, _ := round(host, nil, w, req); !cont {
73                         return
74                 }
75         }
76
77         reqFlags := []string{}
78         unauthorized := false
79
80         caches.HTTPAuthCacheM.RLock()
81         if creds, ok := caches.HTTPAuthCache[req.URL.Host]; ok {
82                 req.SetBasicAuth(creds[0], creds[1])
83                 unauthorized = true
84         }
85         caches.HTTPAuthCacheM.RUnlock()
86
87 Retry:
88         resp, err := transport.RoundTrip(req)
89         if err != nil {
90                 fifos.LogErr <- fmt.Sprintf("%s\t%s", req.URL.Host, err.Error())
91                 http.Error(w, err.Error(), http.StatusBadGateway)
92                 return
93         }
94
95         if resp.StatusCode == http.StatusUnauthorized {
96                 resp.Body.Close()
97                 caches.HTTPAuthCacheM.Lock()
98                 if unauthorized {
99                         delete(caches.HTTPAuthCache, req.URL.Host)
100                 } else {
101                         unauthorized = true
102                 }
103                 fifos.LogVarious <- fmt.Sprintf(
104                         "%s %s\tHTTP authorization required", req.Method, req.URL.Host,
105                 )
106                 user, pass, err := authDialog(host, resp.Header.Get("WWW-Authenticate"))
107                 if err != nil {
108                         caches.HTTPAuthCacheM.Unlock()
109                         fifos.LogErr <- fmt.Sprintf("%s\t%s", req.URL.Host, err.Error())
110                         http.Error(w, err.Error(), http.StatusInternalServerError)
111                         return
112                 }
113                 caches.HTTPAuthCache[req.URL.Host] = [2]string{user, pass}
114                 caches.HTTPAuthCacheM.Unlock()
115                 req.SetBasicAuth(user, pass)
116                 fifos.LogHTTPAuth <- fmt.Sprintf("%s %s\t%s", req.Method, req.URL, user)
117                 goto Retry
118         }
119         if unauthorized {
120                 reqFlags = append(reqFlags, "auth")
121         }
122         if resp.TLS != nil && resp.TLS.NegotiatedProtocol != "" {
123                 reqFlags = append(reqFlags, resp.TLS.NegotiatedProtocol)
124         }
125
126         for k, vs := range resp.Header {
127                 if _, ok := proxyHeaders[k]; ok {
128                         continue
129                 }
130                 for _, v := range vs {
131                         w.Header().Add(k, v)
132                 }
133         }
134
135         for _, round := range []Round{
136                 rounds.RoundDenyFonts,
137                 rounds.RoundTranscodeWebP,
138                 rounds.RoundTranscodeJXL,
139                 rounds.RoundTranscodeAVIF,
140                 rounds.RoundRedirectHTML,
141         } {
142                 cont, err := round(host, resp, w, req)
143                 if err != nil {
144                         http.Error(w, err.Error(), http.StatusBadGateway)
145                         return
146                 }
147                 if !cont {
148                         return
149                 }
150         }
151
152         for h := range proxyHeaders {
153                 if v := resp.Header.Get(h); v != "" {
154                         w.Header().Add(h, v)
155                 }
156         }
157         w.WriteHeader(resp.StatusCode)
158         n, err := io.Copy(w, resp.Body)
159         if err != nil {
160                 log.Printf("Error during %s: %+v\n", req.URL, err)
161         }
162         resp.Body.Close()
163         msg := fmt.Sprintf(
164                 "%s %s\t%s\t%s\t%s\t%s",
165                 req.Method, req.URL,
166                 resp.Status,
167                 resp.Header.Get("Content-Type"),
168                 humanize.IBytes(uint64(n)),
169                 strings.Join(reqFlags, ","),
170         )
171         if resp.StatusCode == http.StatusOK {
172                 fifos.LogOK <- msg
173         } else {
174                 fifos.LogNonOK <- msg
175         }
176 }