Sergey Matveev's repositories - tofuproxy.git/blob - main.go
WebP/JPEG-XL transcoding
[tofuproxy.git] / main.go
1 /*
2 Copyright (C) 2021 Sergey Matveev <stargrave@stargrave.org>
3
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, version 3 of the License.
7
8 This program is distributed in the hope that it will be useful,
9 but WITHOUT ANY WARRANTY; without even the implied warranty of
10 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 GNU General Public License for more details.
12
13 You should have received a copy of the GNU General Public License
14 along with this program.  If not, see <http://www.gnu.org/licenses/>.
15 */
16
17 package main
18
19 import (
20         "context"
21         "crypto"
22         "crypto/sha256"
23         "crypto/tls"
24         "crypto/x509"
25         "encoding/hex"
26         "flag"
27         "fmt"
28         "io"
29         "io/ioutil"
30         "log"
31         "net"
32         "net/http"
33         "os"
34         "os/exec"
35         "strings"
36         "sync"
37         "time"
38
39         "github.com/dustin/go-humanize"
40         "go.cypherpunks.ru/ucspi"
41 )
42
43 var (
44         tlsNextProtoS = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
45         tlsNextProtoC = make(map[string]func(string, *tls.Conn) http.RoundTripper)
46         caCert        *x509.Certificate
47         caPrv         crypto.PrivateKey
48         certs         *string
49         dnsSrv        *string
50         transport     = http.Transport{
51                 ForceAttemptHTTP2:   false,
52                 DisableKeepAlives:   true,
53                 MaxIdleConnsPerHost: 2,
54                 TLSNextProto:        tlsNextProtoC,
55                 DialTLSContext:      dialTLS,
56         }
57
58         accepted  = make(map[string]string)
59         acceptedM sync.RWMutex
60         rejected  = make(map[string]string)
61         rejectedM sync.RWMutex
62
63         CmdDWebP = "dwebp"
64         CmdDJXL  = "djxl"
65 )
66
// spkiHash returns the lowercase hexadecimal SHA-256 digest of the
// certificate's raw DER-encoded SubjectPublicKeyInfo. dialTLS logs it to
// identify the peer's public key.
func spkiHash(cert *x509.Certificate) string {
	digest := sha256.Sum256(cert.RawSubjectPublicKeyInfo)
	encoded := make([]byte, hex.EncodedLen(len(digest)))
	hex.Encode(encoded, digest[:])
	return string(encoded)
}
71
72 func acceptedAdd(addr, h string) {
73         acceptedM.Lock()
74         accepted[addr] = h
75         acceptedM.Unlock()
76 }
77
78 func rejectedAdd(addr, h string) {
79         rejectedM.Lock()
80         rejected[addr] = h
81         rejectedM.Unlock()
82 }
83
// ErrRejected reports that addr was rejected. dialTLS treats this error as
// final and does not retry the connection.
type ErrRejected struct {
	addr string
}

// Error implements the error interface.
func (e ErrRejected) Error() string {
	const suffix = " was rejected"
	return e.addr + suffix
}
89
90 func dialTLS(ctx context.Context, network, addr string) (net.Conn, error) {
91         host := strings.TrimSuffix(addr, ":443")
92         cfg := tls.Config{
93                 VerifyPeerCertificate: func(
94                         rawCerts [][]byte,
95                         verifiedChains [][]*x509.Certificate,
96                 ) error {
97                         return verifyCert(host, nil, rawCerts, verifiedChains)
98                 },
99         }
100         conn, dialErr := tls.Dial(network, addr, &cfg)
101         if dialErr != nil {
102                 if _, ok := dialErr.(ErrRejected); ok {
103                         return nil, dialErr
104                 }
105                 cfg.InsecureSkipVerify = true
106                 cfg.VerifyPeerCertificate = func(
107                         rawCerts [][]byte,
108                         verifiedChains [][]*x509.Certificate,
109                 ) error {
110                         return verifyCert(host, dialErr, rawCerts, verifiedChains)
111                 }
112                 var err error
113                 conn, err = tls.Dial(network, addr, &cfg)
114                 if err != nil {
115                         sinkErr <- fmt.Sprintf("%s\t%s", addr, dialErr.Error())
116                         return nil, err
117                 }
118         }
119         connState := conn.ConnectionState()
120         sinkTLS <- fmt.Sprintf(
121                 "%s\t%s %s\t%s",
122                 strings.TrimSuffix(addr, ":443"),
123                 ucspi.TLSVersion(connState.Version),
124                 tls.CipherSuiteName(connState.CipherSuite),
125                 spkiHash(connState.PeerCertificates[0]),
126         )
127         return conn, nil
128 }
129
130 func roundTrip(w http.ResponseWriter, req *http.Request) {
131         if req.Method == http.MethodHead {
132                 http.Error(w, "go away", http.StatusMethodNotAllowed)
133                 return
134         }
135         sinkReq <- fmt.Sprintf("%s %s", req.Method, req.URL.String())
136         host := strings.TrimSuffix(req.URL.Host, ":443")
137
138         for _, spy := range SpyDomains {
139                 if strings.HasSuffix(host, spy) {
140                         http.NotFound(w, req)
141                         sinkOther <- fmt.Sprintf(
142                                 "%s %s\t%d\tspy one",
143                                 req.Method,
144                                 req.URL.String(),
145                                 http.StatusNotFound,
146                         )
147                         return
148                 }
149         }
150
151         if strings.HasPrefix(req.URL.Host, "www.reddit.com") {
152                 req.URL.Host = "old.reddit.com"
153                 http.Redirect(w, req, req.URL.String(), http.StatusMovedPermanently)
154                 return
155         }
156
157         resp, err := transport.RoundTrip(req)
158         if err != nil {
159                 sinkErr <- fmt.Sprintf("%s\t%s", req.URL.Host, err.Error())
160                 w.WriteHeader(http.StatusBadGateway)
161                 w.Write([]byte(err.Error()))
162                 return
163         }
164
165         for k, vs := range resp.Header {
166                 if k == "Location" || k == "Content-Type" || k == "Content-Length" {
167                         continue
168                 }
169                 for _, v := range vs {
170                         w.Header().Add(k, v)
171                 }
172         }
173
174         switch resp.Header.Get("Content-Type") {
175         case "application/font-woff", "application/font-sfnt":
176                 // Those are deprecated types
177                 fallthrough
178         case "font/otf", "font/ttf", "font/woff", "font/woff2":
179                 http.NotFound(w, req)
180                 sinkOther <- fmt.Sprintf(
181                         "%s %s\t%d\tfonts are not allowed",
182                         req.Method,
183                         req.URL.String(),
184                         http.StatusNotFound,
185                 )
186                 resp.Body.Close()
187                 return
188         case "image/webp":
189                 if strings.Contains(req.Header.Get("User-Agent"), "AppleWebKit/538.15") {
190                         // My Xombrero
191                         break
192                 }
193                 tmpFd, err := ioutil.TempFile("", "tofuproxy.*.webp")
194                 if err != nil {
195                         log.Fatalln(err)
196                 }
197                 defer tmpFd.Close()
198                 defer os.Remove(tmpFd.Name())
199                 defer resp.Body.Close()
200                 if _, err = io.Copy(tmpFd, resp.Body); err != nil {
201                         log.Printf("Error during %s: %+v\n", req.URL, err)
202                         http.Error(w, err.Error(), http.StatusBadGateway)
203                         return
204                 }
205                 tmpFd.Close()
206                 cmd := exec.Command(CmdDWebP, tmpFd.Name(), "-o", "-")
207                 data, err := cmd.Output()
208                 if err != nil {
209                         http.Error(w, err.Error(), http.StatusBadGateway)
210                         return
211                 }
212                 w.Header().Add("Content-Type", "image/png")
213                 w.WriteHeader(http.StatusOK)
214                 w.Write(data)
215                 sinkOther <- fmt.Sprintf(
216                         "%s %s\t%d\tWebP transcoded to PNG",
217                         req.Method,
218                         req.URL.String(),
219                         http.StatusOK,
220                 )
221                 return
222         case "image/jxl":
223                 tmpFd, err := ioutil.TempFile("", "tofuproxy.*.jxl")
224                 if err != nil {
225                         log.Fatalln(err)
226                 }
227                 defer tmpFd.Close()
228                 defer os.Remove(tmpFd.Name())
229                 defer resp.Body.Close()
230                 if _, err = io.Copy(tmpFd, resp.Body); err != nil {
231                         log.Printf("Error during %s: %+v\n", req.URL, err)
232                         http.Error(w, err.Error(), http.StatusBadGateway)
233                         return
234                 }
235                 tmpFd.Close()
236                 dstFn := tmpFd.Name() + ".png"
237                 cmd := exec.Command(CmdDJXL, tmpFd.Name(), dstFn)
238                 err = cmd.Run()
239                 defer os.Remove(dstFn)
240                 if err != nil {
241                         http.Error(w, err.Error(), http.StatusBadGateway)
242                         return
243                 }
244                 data, err := ioutil.ReadFile(dstFn)
245                 if err != nil {
246                         http.Error(w, err.Error(), http.StatusBadGateway)
247                         return
248                 }
249                 w.Header().Add("Content-Type", "image/png")
250                 w.WriteHeader(http.StatusOK)
251                 w.Write(data)
252                 sinkOther <- fmt.Sprintf(
253                         "%s %s\t%d\tJPEG XL transcoded to PNG",
254                         req.Method,
255                         req.URL.String(),
256                         http.StatusOK,
257                 )
258                 return
259         }
260
261         if req.Method == http.MethodGet {
262                 var redirType string
263                 switch resp.StatusCode {
264                 case http.StatusMovedPermanently, http.StatusPermanentRedirect:
265                         redirType = "permanent"
266                         goto Redir
267                 case http.StatusFound, http.StatusSeeOther, http.StatusTemporaryRedirect:
268                         if strings.Contains(req.Header.Get("User-Agent"), "newsboat/") {
269                                 goto NoRedir
270                         }
271                         redirType = "temporary"
272                 default:
273                         goto NoRedir
274                 }
275         Redir:
276                 resp.Body.Close()
277                 w.Header().Add("Content-Type", "text/html")
278                 w.WriteHeader(http.StatusOK)
279                 location := resp.Header.Get("Location")
280                 w.Write([]byte(
281                         fmt.Sprintf(
282                                 `<html>
283 <head><title>%d %s: %s redirection</title></head>
284 <body>Redirection to <a href="%s">%s</a>
285 </body></html>`,
286                                 resp.StatusCode, http.StatusText(resp.StatusCode),
287                                 redirType, location, location,
288                         )))
289                 sinkRedir <- fmt.Sprintf(
290                         "%s %s\t%s\t%s", req.Method, resp.Status, req.URL.String(), location,
291                 )
292                 return
293         }
294
295 NoRedir:
296         for _, h := range []string{"Location", "Content-Type", "Content-Length"} {
297                 if v := resp.Header.Get(h); v != "" {
298                         w.Header().Add(h, v)
299                 }
300         }
301         w.WriteHeader(resp.StatusCode)
302         n, err := io.Copy(w, resp.Body)
303         if err != nil {
304                 log.Printf("Error during %s: %+v\n", req.URL, err)
305         }
306         resp.Body.Close()
307         msg := fmt.Sprintf(
308                 "%s %s\t%s\t%s\t%s",
309                 req.Method,
310                 req.URL.String(),
311                 resp.Status,
312                 resp.Header.Get("Content-Type"),
313                 humanize.IBytes(uint64(n)),
314         )
315         if resp.StatusCode == http.StatusOK {
316                 sinkOK <- msg
317         } else {
318                 sinkOther <- msg
319         }
320 }
321
322 type Handler struct{}
323
324 func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
325         if req.Method != http.MethodConnect {
326                 roundTrip(w, req)
327                 return
328         }
329         hj, ok := w.(http.Hijacker)
330         if !ok {
331                 log.Fatalln("no hijacking")
332         }
333         conn, _, err := hj.Hijack()
334         if err != nil {
335                 log.Fatalln(err)
336         }
337         defer conn.Close()
338         conn.Write([]byte(fmt.Sprintf(
339                 "%s %d %s\r\n\r\n",
340                 req.Proto,
341                 http.StatusOK, http.StatusText(http.StatusOK),
342         )))
343         host := strings.Split(req.Host, ":")[0]
344         hostCertsM.Lock()
345         keypair, ok := hostCerts[host]
346         if !ok || !keypair.cert.NotAfter.After(time.Now().Add(time.Hour)) {
347                 keypair = newKeypair(host, caCert, caPrv)
348                 hostCerts[host] = keypair
349         }
350         hostCertsM.Unlock()
351         tlsConn := tls.Server(conn, &tls.Config{
352                 Certificates: []tls.Certificate{{
353                         Certificate: [][]byte{keypair.cert.Raw},
354                         PrivateKey:  keypair.prv,
355                 }},
356         })
357         if err = tlsConn.Handshake(); err != nil {
358                 log.Printf("TLS error %s: %+v\n", host, err)
359                 return
360         }
361         srv := http.Server{
362                 Handler:      &HTTPSHandler{host: req.Host},
363                 TLSNextProto: tlsNextProtoS,
364         }
365         err = srv.Serve(&SingleListener{conn: tlsConn})
366         if err != nil {
367                 if _, ok := err.(AlreadyAccepted); !ok {
368                         log.Printf("TLS serve error %s: %+v\n", host, err)
369                         return
370                 }
371         }
372 }
373
374 type HTTPSHandler struct {
375         host string
376 }
377
378 func (h *HTTPSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
379         req.URL.Scheme = "https"
380         req.URL.Host = h.host
381         roundTrip(w, req)
382 }
383
384 func main() {
385         crtPath := flag.String("cert", "cert.pem", "Path to server X.509 certificate")
386         prvPath := flag.String("key", "prv.pem", "Path to server PKCS#8 private key")
387         bind := flag.String("bind", "[::1]:8080", "Bind address")
388         certs = flag.String("certs", "certs", "Directory with pinned certificates")
389         dnsSrv = flag.String("dns", "[::1]:53", "DNS server")
390         fifos = flag.String("fifos", "fifos", "Directory with FIFOs")
391         notai = flag.Bool("notai", false, "Do not prepend TAI64N to logs")
392         flag.Parse()
393         log.SetFlags(log.Lshortfile)
394         fifoInit()
395
396         var err error
397         _, caCert, err = ucspi.CertificateFromFile(*crtPath)
398         if err != nil {
399                 log.Fatalln(err)
400         }
401         caPrv, err = ucspi.PrivateKeyFromFile(*prvPath)
402         if err != nil {
403                 log.Fatalln(err)
404         }
405
406         ln, err := net.Listen("tcp", *bind)
407         if err != nil {
408                 log.Fatalln(err)
409         }
410         srv := http.Server{
411                 Handler:      &Handler{},
412                 TLSNextProto: tlsNextProtoS,
413         }
414         srv.SetKeepAlivesEnabled(false)
415         log.Println("listening:", *bind)
416         if err := srv.Serve(ln); err != nil {
417                 log.Fatalln(err)
418         }
419 }