Sergey Matveev's repositories - tofuproxy.git/blob - main.go
Хабр's full size images
[tofuproxy.git] / main.go
1 /*
2 Copyright (C) 2021 Sergey Matveev <stargrave@stargrave.org>
3
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, version 3 of the License.
7
8 This program is distributed in the hope that it will be useful,
9 but WITHOUT ANY WARRANTY; without even the implied warranty of
10 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 GNU General Public License for more details.
12
13 You should have received a copy of the GNU General Public License
14 along with this program.  If not, see <http://www.gnu.org/licenses/>.
15 */
16
17 package main
18
import (
	"context"
	"crypto"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
	"time"

	"github.com/dustin/go-humanize"
	"go.cypherpunks.ru/ucspi"
)
40
41 var (
42         tlsNextProtoS = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
43         caCert        *x509.Certificate
44         caPrv         crypto.PrivateKey
45         transport     = http.Transport{
46                 DialTLSContext:    dialTLS,
47                 ForceAttemptHTTP2: true,
48         }
49         sessionCache = tls.NewLRUClientSessionCache(1024)
50
51         CmdDWebP = "dwebp"
52         CmdDJXL  = "djxl"
53
54         imageExts = map[string]struct{}{
55                 ".apng": {},
56                 ".avif": {},
57                 ".gif":  {},
58                 ".heic": {},
59                 ".jp2":  {},
60                 ".jpeg": {},
61                 ".jpg":  {},
62                 ".jxl":  {},
63                 ".mng":  {},
64                 ".png":  {},
65                 ".svg":  {},
66                 ".tif":  {},
67                 ".tiff": {},
68                 ".webp": {},
69         }
70 )
71
72 func dialTLS(ctx context.Context, network, addr string) (net.Conn, error) {
73         host := strings.TrimSuffix(addr, ":443")
74         cfg := tls.Config{
75                 VerifyPeerCertificate: func(
76                         rawCerts [][]byte,
77                         verifiedChains [][]*x509.Certificate,
78                 ) error {
79                         return verifyCert(host, nil, rawCerts, verifiedChains)
80                 },
81                 ClientSessionCache: sessionCache,
82                 NextProtos:         []string{"h2", "http/1.1"},
83         }
84         conn, dialErr := tls.Dial(network, addr, &cfg)
85         if dialErr != nil {
86                 if _, ok := dialErr.(ErrRejected); ok {
87                         return nil, dialErr
88                 }
89                 cfg.InsecureSkipVerify = true
90                 cfg.VerifyPeerCertificate = func(
91                         rawCerts [][]byte,
92                         verifiedChains [][]*x509.Certificate,
93                 ) error {
94                         return verifyCert(host, dialErr, rawCerts, verifiedChains)
95                 }
96                 var err error
97                 conn, err = tls.Dial(network, addr, &cfg)
98                 if err != nil {
99                         sinkErr <- fmt.Sprintf("%s\t%s", addr, dialErr.Error())
100                         return nil, err
101                 }
102         }
103         connState := conn.ConnectionState()
104         if connState.DidResume {
105                 sinkTLS <- fmt.Sprintf(
106                         "%s\t%s %s\t%s\t%s",
107                         strings.TrimSuffix(addr, ":443"),
108                         ucspi.TLSVersion(connState.Version),
109                         tls.CipherSuiteName(connState.CipherSuite),
110                         spkiHash(connState.PeerCertificates[0]),
111                         connState.NegotiatedProtocol,
112                 )
113         }
114         return conn, nil
115 }
116
117 func roundTrip(w http.ResponseWriter, req *http.Request) {
118         if req.Method == http.MethodHead {
119                 http.Error(w, "go away", http.StatusMethodNotAllowed)
120                 return
121         }
122         sinkReq <- fmt.Sprintf("%s %s", req.Method, req.URL.String())
123         host := strings.TrimSuffix(req.URL.Host, ":443")
124
125         for _, spy := range SpyDomains {
126                 if strings.HasSuffix(host, spy) {
127                         http.NotFound(w, req)
128                         sinkOther <- fmt.Sprintf(
129                                 "%s %s\t%d\tspy one",
130                                 req.Method,
131                                 req.URL.String(),
132                                 http.StatusNotFound,
133                         )
134                         return
135                 }
136         }
137
138         if host == "www.reddit.com" {
139                 req.URL.Host = "old.reddit.com"
140                 http.Redirect(w, req, req.URL.String(), http.StatusMovedPermanently)
141                 return
142         }
143
144         if host == "habrastorage.org" &&
145                 strings.Contains(req.URL.Path, "/webt/") &&
146                 !strings.HasPrefix(req.URL.Path, "/webt/") {
147                 req.URL.Path = req.URL.Path[strings.Index(req.URL.Path, "/webt/"):]
148                 http.Redirect(w, req, req.URL.String(), http.StatusFound)
149                 return
150         }
151
152         resp, err := transport.RoundTrip(req)
153         if err != nil {
154                 sinkErr <- fmt.Sprintf("%s\t%s", req.URL.Host, err.Error())
155                 w.WriteHeader(http.StatusBadGateway)
156                 w.Write([]byte(err.Error()))
157                 return
158         }
159
160         for k, vs := range resp.Header {
161                 if k == "Location" || k == "Content-Type" || k == "Content-Length" {
162                         continue
163                 }
164                 for _, v := range vs {
165                         w.Header().Add(k, v)
166                 }
167         }
168
169         switch resp.Header.Get("Content-Type") {
170         case "application/font-woff", "application/font-sfnt":
171                 // Those are deprecated types
172                 fallthrough
173         case "font/otf", "font/ttf", "font/woff", "font/woff2":
174                 http.NotFound(w, req)
175                 sinkOther <- fmt.Sprintf(
176                         "%s %s\t%d\tfonts are not allowed",
177                         req.Method,
178                         req.URL.String(),
179                         http.StatusNotFound,
180                 )
181                 resp.Body.Close()
182                 return
183         case "image/webp":
184                 if strings.Contains(req.Header.Get("User-Agent"), "AppleWebKit/538.15") {
185                         // My Xombrero
186                         break
187                 }
188                 tmpFd, err := ioutil.TempFile("", "tofuproxy.*.webp")
189                 if err != nil {
190                         log.Fatalln(err)
191                 }
192                 defer tmpFd.Close()
193                 defer os.Remove(tmpFd.Name())
194                 defer resp.Body.Close()
195                 if _, err = io.Copy(tmpFd, resp.Body); err != nil {
196                         log.Printf("Error during %s: %+v\n", req.URL, err)
197                         http.Error(w, err.Error(), http.StatusBadGateway)
198                         return
199                 }
200                 tmpFd.Close()
201                 cmd := exec.Command(CmdDWebP, tmpFd.Name(), "-o", "-")
202                 data, err := cmd.Output()
203                 if err != nil {
204                         http.Error(w, err.Error(), http.StatusBadGateway)
205                         return
206                 }
207                 w.Header().Add("Content-Type", "image/png")
208                 w.WriteHeader(http.StatusOK)
209                 w.Write(data)
210                 sinkOther <- fmt.Sprintf(
211                         "%s %s\t%d\tWebP transcoded to PNG",
212                         req.Method,
213                         req.URL.String(),
214                         http.StatusOK,
215                 )
216                 return
217         case "image/jxl":
218                 tmpFd, err := ioutil.TempFile("", "tofuproxy.*.jxl")
219                 if err != nil {
220                         log.Fatalln(err)
221                 }
222                 defer tmpFd.Close()
223                 defer os.Remove(tmpFd.Name())
224                 defer resp.Body.Close()
225                 if _, err = io.Copy(tmpFd, resp.Body); err != nil {
226                         log.Printf("Error during %s: %+v\n", req.URL, err)
227                         http.Error(w, err.Error(), http.StatusBadGateway)
228                         return
229                 }
230                 tmpFd.Close()
231                 dstFn := tmpFd.Name() + ".png"
232                 cmd := exec.Command(CmdDJXL, tmpFd.Name(), dstFn)
233                 err = cmd.Run()
234                 defer os.Remove(dstFn)
235                 if err != nil {
236                         http.Error(w, err.Error(), http.StatusBadGateway)
237                         return
238                 }
239                 data, err := ioutil.ReadFile(dstFn)
240                 if err != nil {
241                         http.Error(w, err.Error(), http.StatusBadGateway)
242                         return
243                 }
244                 w.Header().Add("Content-Type", "image/png")
245                 w.WriteHeader(http.StatusOK)
246                 w.Write(data)
247                 sinkOther <- fmt.Sprintf(
248                         "%s %s\t%d\tJPEG XL transcoded to PNG",
249                         req.Method,
250                         req.URL.String(),
251                         http.StatusOK,
252                 )
253                 return
254         }
255
256         if req.Method == http.MethodGet {
257                 var redirType string
258                 switch resp.StatusCode {
259                 case http.StatusMovedPermanently, http.StatusPermanentRedirect:
260                         redirType = "permanent"
261                         goto Redir
262                 case http.StatusFound, http.StatusSeeOther, http.StatusTemporaryRedirect:
263                         if strings.Contains(req.Header.Get("User-Agent"), "newsboat/") {
264                                 goto NoRedir
265                         }
266                         if _, ok := imageExts[filepath.Ext(req.URL.Path)]; ok {
267                                 goto NoRedir
268                         }
269                         redirType = "temporary"
270                 default:
271                         goto NoRedir
272                 }
273         Redir:
274                 resp.Body.Close()
275                 w.Header().Add("Content-Type", "text/html")
276                 w.WriteHeader(http.StatusOK)
277                 location := resp.Header.Get("Location")
278                 w.Write([]byte(
279                         fmt.Sprintf(
280                                 `<html><head><title>%d %s: %s redirection</title></head>
281 <body>Redirection to <a href="%s">%s</a></body></html>`,
282                                 resp.StatusCode, http.StatusText(resp.StatusCode),
283                                 redirType, location, location,
284                         )))
285                 sinkRedir <- fmt.Sprintf(
286                         "%s %s\t%s\t%s", req.Method, resp.Status, req.URL.String(), location,
287                 )
288                 return
289         }
290
291 NoRedir:
292         for _, h := range []string{"Location", "Content-Type", "Content-Length"} {
293                 if v := resp.Header.Get(h); v != "" {
294                         w.Header().Add(h, v)
295                 }
296         }
297         w.WriteHeader(resp.StatusCode)
298         n, err := io.Copy(w, resp.Body)
299         if err != nil {
300                 log.Printf("Error during %s: %+v\n", req.URL, err)
301         }
302         resp.Body.Close()
303         msg := fmt.Sprintf(
304                 "%s %s\t%s\t%s\t%s",
305                 req.Method,
306                 req.URL.String(),
307                 resp.Status,
308                 resp.Header.Get("Content-Type"),
309                 humanize.IBytes(uint64(n)),
310         )
311         if resp.StatusCode == http.StatusOK {
312                 sinkOK <- msg
313         } else {
314                 sinkOther <- msg
315         }
316 }
317
318 type Handler struct{}
319
320 func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
321         if req.Method != http.MethodConnect {
322                 roundTrip(w, req)
323                 return
324         }
325         hj, ok := w.(http.Hijacker)
326         if !ok {
327                 log.Fatalln("no hijacking")
328         }
329         conn, _, err := hj.Hijack()
330         if err != nil {
331                 log.Fatalln(err)
332         }
333         defer conn.Close()
334         conn.Write([]byte(fmt.Sprintf(
335                 "%s %d %s\r\n\r\n",
336                 req.Proto,
337                 http.StatusOK, http.StatusText(http.StatusOK),
338         )))
339         host := strings.Split(req.Host, ":")[0]
340         hostCertsM.Lock()
341         keypair, ok := hostCerts[host]
342         if !ok || !keypair.cert.NotAfter.After(time.Now().Add(time.Hour)) {
343                 keypair = newKeypair(host, caCert, caPrv)
344                 hostCerts[host] = keypair
345         }
346         hostCertsM.Unlock()
347         tlsConn := tls.Server(conn, &tls.Config{
348                 Certificates: []tls.Certificate{{
349                         Certificate: [][]byte{keypair.cert.Raw},
350                         PrivateKey:  keypair.prv,
351                 }},
352         })
353         if err = tlsConn.Handshake(); err != nil {
354                 log.Printf("TLS error %s: %+v\n", host, err)
355                 return
356         }
357         srv := http.Server{
358                 Handler:      &HTTPSHandler{host: req.Host},
359                 TLSNextProto: tlsNextProtoS,
360         }
361         err = srv.Serve(&SingleListener{conn: tlsConn})
362         if err != nil {
363                 if _, ok := err.(AlreadyAccepted); !ok {
364                         log.Printf("TLS serve error %s: %+v\n", host, err)
365                         return
366                 }
367         }
368 }
369
370 type HTTPSHandler struct {
371         host string
372 }
373
374 func (h *HTTPSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
375         req.URL.Scheme = "https"
376         req.URL.Host = h.host
377         roundTrip(w, req)
378 }
379
380 func main() {
381         crtPath := flag.String("cert", "cert.pem", "Path to server X.509 certificate")
382         prvPath := flag.String("key", "prv.pem", "Path to server PKCS#8 private key")
383         bind := flag.String("bind", "[::1]:8080", "Bind address")
384         certs = flag.String("certs", "certs", "Directory with pinned certificates")
385         dnsSrv = flag.String("dns", "[::1]:53", "DNS server")
386         fifos = flag.String("fifos", "fifos", "Directory with FIFOs")
387         notai = flag.Bool("notai", false, "Do not prepend TAI64N to logs")
388         flag.Parse()
389         log.SetFlags(log.Lshortfile)
390         fifoInit()
391
392         var err error
393         _, caCert, err = ucspi.CertificateFromFile(*crtPath)
394         if err != nil {
395                 log.Fatalln(err)
396         }
397         caPrv, err = ucspi.PrivateKeyFromFile(*prvPath)
398         if err != nil {
399                 log.Fatalln(err)
400         }
401
402         ln, err := net.Listen("tcp", *bind)
403         if err != nil {
404                 log.Fatalln(err)
405         }
406         srv := http.Server{
407                 Handler:      &Handler{},
408                 TLSNextProto: tlsNextProtoS,
409         }
410         log.Println("listening:", *bind)
411         if err := srv.Serve(ln); err != nil {
412                 log.Fatalln(err)
413         }
414 }