]> Sergey Matveev's repositories - tofuproxy.git/blob - main.go
HTTP/2.0
[tofuproxy.git] / main.go
1 /*
2 Copyright (C) 2021 Sergey Matveev <stargrave@stargrave.org>
3
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, version 3 of the License.
7
8 This program is distributed in the hope that it will be useful,
9 but WITHOUT ANY WARRANTY; without even the implied warranty of
10 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 GNU General Public License for more details.
12
13 You should have received a copy of the GNU General Public License
14 along with this program.  If not, see <http://www.gnu.org/licenses/>.
15 */
16
17 package main
18
19 import (
20         "context"
21         "crypto"
22         "crypto/tls"
23         "crypto/x509"
24         "flag"
25         "fmt"
26         "io"
27         "io/ioutil"
28         "log"
29         "net"
30         "net/http"
31         "os"
32         "os/exec"
33         "strings"
34         "time"
35
36         "github.com/dustin/go-humanize"
37         "go.cypherpunks.ru/ucspi"
38 )
39
40 var (
41         tlsNextProtoS = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
42         caCert        *x509.Certificate
43         caPrv         crypto.PrivateKey
44         transport     = http.Transport{
45                 DialTLSContext:    dialTLS,
46                 ForceAttemptHTTP2: true,
47         }
48         sessionCache = tls.NewLRUClientSessionCache(1024)
49
50         CmdDWebP = "dwebp"
51         CmdDJXL  = "djxl"
52 )
53
54 func dialTLS(ctx context.Context, network, addr string) (net.Conn, error) {
55         host := strings.TrimSuffix(addr, ":443")
56         cfg := tls.Config{
57                 VerifyPeerCertificate: func(
58                         rawCerts [][]byte,
59                         verifiedChains [][]*x509.Certificate,
60                 ) error {
61                         return verifyCert(host, nil, rawCerts, verifiedChains)
62                 },
63                 ClientSessionCache: sessionCache,
64                 NextProtos:         []string{"h2", "http/1.1"},
65         }
66         conn, dialErr := tls.Dial(network, addr, &cfg)
67         if dialErr != nil {
68                 if _, ok := dialErr.(ErrRejected); ok {
69                         return nil, dialErr
70                 }
71                 cfg.InsecureSkipVerify = true
72                 cfg.VerifyPeerCertificate = func(
73                         rawCerts [][]byte,
74                         verifiedChains [][]*x509.Certificate,
75                 ) error {
76                         return verifyCert(host, dialErr, rawCerts, verifiedChains)
77                 }
78                 var err error
79                 conn, err = tls.Dial(network, addr, &cfg)
80                 if err != nil {
81                         sinkErr <- fmt.Sprintf("%s\t%s", addr, dialErr.Error())
82                         return nil, err
83                 }
84         }
85         connState := conn.ConnectionState()
86         if connState.DidResume {
87                 sinkTLS <- fmt.Sprintf(
88                         "%s\t%s %s\t%s\t%s",
89                         strings.TrimSuffix(addr, ":443"),
90                         ucspi.TLSVersion(connState.Version),
91                         tls.CipherSuiteName(connState.CipherSuite),
92                         spkiHash(connState.PeerCertificates[0]),
93                         connState.NegotiatedProtocol,
94                 )
95         }
96         return conn, nil
97 }
98
99 func roundTrip(w http.ResponseWriter, req *http.Request) {
100         if req.Method == http.MethodHead {
101                 http.Error(w, "go away", http.StatusMethodNotAllowed)
102                 return
103         }
104         sinkReq <- fmt.Sprintf("%s %s", req.Method, req.URL.String())
105         host := strings.TrimSuffix(req.URL.Host, ":443")
106
107         for _, spy := range SpyDomains {
108                 if strings.HasSuffix(host, spy) {
109                         http.NotFound(w, req)
110                         sinkOther <- fmt.Sprintf(
111                                 "%s %s\t%d\tspy one",
112                                 req.Method,
113                                 req.URL.String(),
114                                 http.StatusNotFound,
115                         )
116                         return
117                 }
118         }
119
120         if strings.HasPrefix(req.URL.Host, "www.reddit.com") {
121                 req.URL.Host = "old.reddit.com"
122                 http.Redirect(w, req, req.URL.String(), http.StatusMovedPermanently)
123                 return
124         }
125
126         resp, err := transport.RoundTrip(req)
127         if err != nil {
128                 sinkErr <- fmt.Sprintf("%s\t%s", req.URL.Host, err.Error())
129                 w.WriteHeader(http.StatusBadGateway)
130                 w.Write([]byte(err.Error()))
131                 return
132         }
133
134         for k, vs := range resp.Header {
135                 if k == "Location" || k == "Content-Type" || k == "Content-Length" {
136                         continue
137                 }
138                 for _, v := range vs {
139                         w.Header().Add(k, v)
140                 }
141         }
142
143         switch resp.Header.Get("Content-Type") {
144         case "application/font-woff", "application/font-sfnt":
145                 // Those are deprecated types
146                 fallthrough
147         case "font/otf", "font/ttf", "font/woff", "font/woff2":
148                 http.NotFound(w, req)
149                 sinkOther <- fmt.Sprintf(
150                         "%s %s\t%d\tfonts are not allowed",
151                         req.Method,
152                         req.URL.String(),
153                         http.StatusNotFound,
154                 )
155                 resp.Body.Close()
156                 return
157         case "image/webp":
158                 if strings.Contains(req.Header.Get("User-Agent"), "AppleWebKit/538.15") {
159                         // My Xombrero
160                         break
161                 }
162                 tmpFd, err := ioutil.TempFile("", "tofuproxy.*.webp")
163                 if err != nil {
164                         log.Fatalln(err)
165                 }
166                 defer tmpFd.Close()
167                 defer os.Remove(tmpFd.Name())
168                 defer resp.Body.Close()
169                 if _, err = io.Copy(tmpFd, resp.Body); err != nil {
170                         log.Printf("Error during %s: %+v\n", req.URL, err)
171                         http.Error(w, err.Error(), http.StatusBadGateway)
172                         return
173                 }
174                 tmpFd.Close()
175                 cmd := exec.Command(CmdDWebP, tmpFd.Name(), "-o", "-")
176                 data, err := cmd.Output()
177                 if err != nil {
178                         http.Error(w, err.Error(), http.StatusBadGateway)
179                         return
180                 }
181                 w.Header().Add("Content-Type", "image/png")
182                 w.WriteHeader(http.StatusOK)
183                 w.Write(data)
184                 sinkOther <- fmt.Sprintf(
185                         "%s %s\t%d\tWebP transcoded to PNG",
186                         req.Method,
187                         req.URL.String(),
188                         http.StatusOK,
189                 )
190                 return
191         case "image/jxl":
192                 tmpFd, err := ioutil.TempFile("", "tofuproxy.*.jxl")
193                 if err != nil {
194                         log.Fatalln(err)
195                 }
196                 defer tmpFd.Close()
197                 defer os.Remove(tmpFd.Name())
198                 defer resp.Body.Close()
199                 if _, err = io.Copy(tmpFd, resp.Body); err != nil {
200                         log.Printf("Error during %s: %+v\n", req.URL, err)
201                         http.Error(w, err.Error(), http.StatusBadGateway)
202                         return
203                 }
204                 tmpFd.Close()
205                 dstFn := tmpFd.Name() + ".png"
206                 cmd := exec.Command(CmdDJXL, tmpFd.Name(), dstFn)
207                 err = cmd.Run()
208                 defer os.Remove(dstFn)
209                 if err != nil {
210                         http.Error(w, err.Error(), http.StatusBadGateway)
211                         return
212                 }
213                 data, err := ioutil.ReadFile(dstFn)
214                 if err != nil {
215                         http.Error(w, err.Error(), http.StatusBadGateway)
216                         return
217                 }
218                 w.Header().Add("Content-Type", "image/png")
219                 w.WriteHeader(http.StatusOK)
220                 w.Write(data)
221                 sinkOther <- fmt.Sprintf(
222                         "%s %s\t%d\tJPEG XL transcoded to PNG",
223                         req.Method,
224                         req.URL.String(),
225                         http.StatusOK,
226                 )
227                 return
228         }
229
230         if req.Method == http.MethodGet {
231                 var redirType string
232                 switch resp.StatusCode {
233                 case http.StatusMovedPermanently, http.StatusPermanentRedirect:
234                         redirType = "permanent"
235                         goto Redir
236                 case http.StatusFound, http.StatusSeeOther, http.StatusTemporaryRedirect:
237                         if strings.Contains(req.Header.Get("User-Agent"), "newsboat/") {
238                                 goto NoRedir
239                         }
240                         redirType = "temporary"
241                 default:
242                         goto NoRedir
243                 }
244         Redir:
245                 resp.Body.Close()
246                 w.Header().Add("Content-Type", "text/html")
247                 w.WriteHeader(http.StatusOK)
248                 location := resp.Header.Get("Location")
249                 w.Write([]byte(
250                         fmt.Sprintf(
251                                 `<html>
252 <head><title>%d %s: %s redirection</title></head>
253 <body>Redirection to <a href="%s">%s</a>
254 </body></html>`,
255                                 resp.StatusCode, http.StatusText(resp.StatusCode),
256                                 redirType, location, location,
257                         )))
258                 sinkRedir <- fmt.Sprintf(
259                         "%s %s\t%s\t%s", req.Method, resp.Status, req.URL.String(), location,
260                 )
261                 return
262         }
263
264 NoRedir:
265         for _, h := range []string{"Location", "Content-Type", "Content-Length"} {
266                 if v := resp.Header.Get(h); v != "" {
267                         w.Header().Add(h, v)
268                 }
269         }
270         w.WriteHeader(resp.StatusCode)
271         n, err := io.Copy(w, resp.Body)
272         if err != nil {
273                 log.Printf("Error during %s: %+v\n", req.URL, err)
274         }
275         resp.Body.Close()
276         msg := fmt.Sprintf(
277                 "%s %s\t%s\t%s\t%s",
278                 req.Method,
279                 req.URL.String(),
280                 resp.Status,
281                 resp.Header.Get("Content-Type"),
282                 humanize.IBytes(uint64(n)),
283         )
284         if resp.StatusCode == http.StatusOK {
285                 sinkOK <- msg
286         } else {
287                 sinkOther <- msg
288         }
289 }
290
291 type Handler struct{}
292
293 func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
294         if req.Method != http.MethodConnect {
295                 roundTrip(w, req)
296                 return
297         }
298         hj, ok := w.(http.Hijacker)
299         if !ok {
300                 log.Fatalln("no hijacking")
301         }
302         conn, _, err := hj.Hijack()
303         if err != nil {
304                 log.Fatalln(err)
305         }
306         defer conn.Close()
307         conn.Write([]byte(fmt.Sprintf(
308                 "%s %d %s\r\n\r\n",
309                 req.Proto,
310                 http.StatusOK, http.StatusText(http.StatusOK),
311         )))
312         host := strings.Split(req.Host, ":")[0]
313         hostCertsM.Lock()
314         keypair, ok := hostCerts[host]
315         if !ok || !keypair.cert.NotAfter.After(time.Now().Add(time.Hour)) {
316                 keypair = newKeypair(host, caCert, caPrv)
317                 hostCerts[host] = keypair
318         }
319         hostCertsM.Unlock()
320         tlsConn := tls.Server(conn, &tls.Config{
321                 Certificates: []tls.Certificate{{
322                         Certificate: [][]byte{keypair.cert.Raw},
323                         PrivateKey:  keypair.prv,
324                 }},
325         })
326         if err = tlsConn.Handshake(); err != nil {
327                 log.Printf("TLS error %s: %+v\n", host, err)
328                 return
329         }
330         srv := http.Server{
331                 Handler:      &HTTPSHandler{host: req.Host},
332                 TLSNextProto: tlsNextProtoS,
333         }
334         err = srv.Serve(&SingleListener{conn: tlsConn})
335         if err != nil {
336                 if _, ok := err.(AlreadyAccepted); !ok {
337                         log.Printf("TLS serve error %s: %+v\n", host, err)
338                         return
339                 }
340         }
341 }
342
343 type HTTPSHandler struct {
344         host string
345 }
346
347 func (h *HTTPSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
348         req.URL.Scheme = "https"
349         req.URL.Host = h.host
350         roundTrip(w, req)
351 }
352
353 func main() {
354         crtPath := flag.String("cert", "cert.pem", "Path to server X.509 certificate")
355         prvPath := flag.String("key", "prv.pem", "Path to server PKCS#8 private key")
356         bind := flag.String("bind", "[::1]:8080", "Bind address")
357         certs = flag.String("certs", "certs", "Directory with pinned certificates")
358         dnsSrv = flag.String("dns", "[::1]:53", "DNS server")
359         fifos = flag.String("fifos", "fifos", "Directory with FIFOs")
360         notai = flag.Bool("notai", false, "Do not prepend TAI64N to logs")
361         flag.Parse()
362         log.SetFlags(log.Lshortfile)
363         fifoInit()
364
365         var err error
366         _, caCert, err = ucspi.CertificateFromFile(*crtPath)
367         if err != nil {
368                 log.Fatalln(err)
369         }
370         caPrv, err = ucspi.PrivateKeyFromFile(*prvPath)
371         if err != nil {
372                 log.Fatalln(err)
373         }
374
375         ln, err := net.Listen("tcp", *bind)
376         if err != nil {
377                 log.Fatalln(err)
378         }
379         srv := http.Server{
380                 Handler:      &Handler{},
381                 TLSNextProto: tlsNextProtoS,
382         }
383         log.Println("listening:", *bind)
384         if err := srv.Serve(ln); err != nil {
385                 log.Fatalln(err)
386         }
387 }