Sergey Matveev's repositories - tofuproxy.git/blob - trip.go
Do not forget about body closing
[tofuproxy.git] / trip.go
1 /*
2 tofuproxy -- HTTP proxy with TLS certificates management
3 Copyright (C) 2021 Sergey Matveev <stargrave@stargrave.org>
4
5 This program is free software: you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation, version 3 of the License.
8
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 GNU General Public License for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 package tofuproxy
19
20 import (
21         "fmt"
22         "io"
23         "log"
24         "net"
25         "net/http"
26         "strings"
27         "time"
28
29         "github.com/dustin/go-humanize"
30         "go.stargrave.org/tofuproxy/caches"
31         "go.stargrave.org/tofuproxy/fifos"
32         "go.stargrave.org/tofuproxy/rounds"
33 )
34
35 var (
36         transport = http.Transport{
37                 DialContext: (&net.Dialer{
38                         Timeout:   time.Minute,
39                         KeepAlive: time.Minute,
40                 }).DialContext,
41                 MaxIdleConns:        http.DefaultTransport.(*http.Transport).MaxIdleConns,
42                 IdleConnTimeout:     http.DefaultTransport.(*http.Transport).IdleConnTimeout * 2,
43                 TLSHandshakeTimeout: time.Minute,
44                 DialTLSContext:      dialTLS,
45                 ForceAttemptHTTP2:   true,
46         }
47         proxyHeaders = map[string]struct{}{
48                 "Location":       {},
49                 "Content-Type":   {},
50                 "Content-Length": {},
51         }
52 )
53
// Round is one processing step of roundTrip. It is given the target host
// (with a ":443" suffix trimmed), the upstream response (nil for steps run
// before the request is forwarded), the client's response writer and the
// client's request. Returning false makes roundTrip stop immediately without
// writing anything further, so such a step presumably answers the client
// itself. For steps run after the upstream response arrived, a non-nil error
// is reported to the client as 502 Bad Gateway.
type Round func(host string, resp *http.Response, w http.ResponseWriter, req *http.Request) (bool, error)
60
61 func roundTrip(w http.ResponseWriter, req *http.Request) {
62         defer req.Body.Close()
63         fifos.LogReq <- fmt.Sprintf("%s %s", req.Method, req.URL)
64         host := strings.TrimSuffix(req.URL.Host, ":443")
65         for _, round := range []Round{
66                 rounds.RoundNoHead,
67                 rounds.RoundDenySpy,
68                 rounds.RoundRedditOld,
69                 rounds.RoundHabrImage,
70         } {
71                 if cont, _ := round(host, nil, w, req); !cont {
72                         return
73                 }
74         }
75
76         reqFlags := []string{}
77         unauthorized := false
78
79         caches.HTTPAuthCacheM.RLock()
80         if creds, ok := caches.HTTPAuthCache[req.URL.Host]; ok {
81                 req.SetBasicAuth(creds[0], creds[1])
82                 unauthorized = true
83         }
84         caches.HTTPAuthCacheM.RUnlock()
85
86 Retry:
87         resp, err := transport.RoundTrip(req)
88         if err != nil {
89                 fifos.LogErr <- fmt.Sprintf("%s\t%s", req.URL.Host, err.Error())
90                 http.Error(w, err.Error(), http.StatusBadGateway)
91                 return
92         }
93
94         if resp.StatusCode == http.StatusUnauthorized {
95                 resp.Body.Close()
96                 caches.HTTPAuthCacheM.Lock()
97                 if unauthorized {
98                         delete(caches.HTTPAuthCache, req.URL.Host)
99                 } else {
100                         unauthorized = true
101                 }
102                 fifos.LogVarious <- fmt.Sprintf(
103                         "%s %s\tHTTP authorization required", req.Method, req.URL.Host,
104                 )
105                 user, pass, err := authDialog(host, resp.Header.Get("WWW-Authenticate"))
106                 if err != nil {
107                         caches.HTTPAuthCacheM.Unlock()
108                         fifos.LogErr <- fmt.Sprintf("%s\t%s", req.URL.Host, err.Error())
109                         http.Error(w, err.Error(), http.StatusInternalServerError)
110                         return
111                 }
112                 caches.HTTPAuthCache[req.URL.Host] = [2]string{user, pass}
113                 caches.HTTPAuthCacheM.Unlock()
114                 req.SetBasicAuth(user, pass)
115                 fifos.LogHTTPAuth <- fmt.Sprintf("%s %s\t%s", req.Method, req.URL, user)
116                 goto Retry
117         }
118         if unauthorized {
119                 reqFlags = append(reqFlags, "auth")
120         }
121         if resp.TLS != nil && resp.TLS.NegotiatedProtocol != "" {
122                 reqFlags = append(reqFlags, resp.TLS.NegotiatedProtocol)
123         }
124
125         for k, vs := range resp.Header {
126                 if _, ok := proxyHeaders[k]; ok {
127                         continue
128                 }
129                 for _, v := range vs {
130                         w.Header().Add(k, v)
131                 }
132         }
133
134         for _, round := range []Round{
135                 rounds.RoundDenyFonts,
136                 rounds.RoundTranscodeWebP,
137                 rounds.RoundTranscodeJXL,
138                 rounds.RoundTranscodeAVIF,
139                 rounds.RoundRedirectHTML,
140         } {
141                 cont, err := round(host, resp, w, req)
142                 if err != nil {
143                         http.Error(w, err.Error(), http.StatusBadGateway)
144                         return
145                 }
146                 if !cont {
147                         return
148                 }
149         }
150
151         for h := range proxyHeaders {
152                 if v := resp.Header.Get(h); v != "" {
153                         w.Header().Add(h, v)
154                 }
155         }
156         w.WriteHeader(resp.StatusCode)
157         n, err := io.Copy(w, resp.Body)
158         if err != nil {
159                 log.Printf("Error during %s: %+v\n", req.URL, err)
160         }
161         resp.Body.Close()
162         msg := fmt.Sprintf(
163                 "%s %s\t%s\t%s\t%s\t%s",
164                 req.Method, req.URL,
165                 resp.Status,
166                 resp.Header.Get("Content-Type"),
167                 humanize.IBytes(uint64(n)),
168                 strings.Join(reqFlags, ","),
169         )
170         if resp.StatusCode == http.StatusOK {
171                 fifos.LogOK <- msg
172         } else {
173                 fifos.LogNonOK <- msg
174         }
175 }