]> Sergey Matveev's repositories - tofuproxy.git/blob - main.go
TLS session resumption support
[tofuproxy.git] / main.go
1 /*
2 Copyright (C) 2021 Sergey Matveev <stargrave@stargrave.org>
3
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, version 3 of the License.
7
8 This program is distributed in the hope that it will be useful,
9 but WITHOUT ANY WARRANTY; without even the implied warranty of
10 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 GNU General Public License for more details.
12
13 You should have received a copy of the GNU General Public License
14 along with this program.  If not, see <http://www.gnu.org/licenses/>.
15 */
16
17 package main
18
19 import (
20         "context"
21         "crypto"
22         "crypto/tls"
23         "crypto/x509"
24         "flag"
25         "fmt"
26         "io"
27         "io/ioutil"
28         "log"
29         "net"
30         "net/http"
31         "os"
32         "os/exec"
33         "strings"
34         "time"
35
36         "github.com/dustin/go-humanize"
37         "go.cypherpunks.ru/ucspi"
38 )
39
40 var (
41         tlsNextProtoS = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
42         caCert        *x509.Certificate
43         caPrv         crypto.PrivateKey
44         transport     = http.Transport{
45                 ForceAttemptHTTP2: false,
46                 TLSNextProto:      make(map[string]func(string, *tls.Conn) http.RoundTripper),
47                 DialTLSContext:    dialTLS,
48         }
49         sessionCache = tls.NewLRUClientSessionCache(1024)
50
51         CmdDWebP = "dwebp"
52         CmdDJXL  = "djxl"
53 )
54
55 func dialTLS(ctx context.Context, network, addr string) (net.Conn, error) {
56         host := strings.TrimSuffix(addr, ":443")
57         cfg := tls.Config{
58                 VerifyPeerCertificate: func(
59                         rawCerts [][]byte,
60                         verifiedChains [][]*x509.Certificate,
61                 ) error {
62                         return verifyCert(host, nil, rawCerts, verifiedChains)
63                 },
64                 ClientSessionCache: sessionCache,
65         }
66         conn, dialErr := tls.Dial(network, addr, &cfg)
67         if dialErr != nil {
68                 if _, ok := dialErr.(ErrRejected); ok {
69                         return nil, dialErr
70                 }
71                 cfg.InsecureSkipVerify = true
72                 cfg.VerifyPeerCertificate = func(
73                         rawCerts [][]byte,
74                         verifiedChains [][]*x509.Certificate,
75                 ) error {
76                         return verifyCert(host, dialErr, rawCerts, verifiedChains)
77                 }
78                 var err error
79                 conn, err = tls.Dial(network, addr, &cfg)
80                 if err != nil {
81                         sinkErr <- fmt.Sprintf("%s\t%s", addr, dialErr.Error())
82                         return nil, err
83                 }
84         }
85         connState := conn.ConnectionState()
86         msg := fmt.Sprintf(
87                 "%s\t%s %s\t%s",
88                 strings.TrimSuffix(addr, ":443"),
89                 ucspi.TLSVersion(connState.Version),
90                 tls.CipherSuiteName(connState.CipherSuite),
91                 spkiHash(connState.PeerCertificates[0]),
92         )
93         if connState.DidResume {
94                 msg += "\tresumed"
95         }
96         sinkTLS <- msg
97         return conn, nil
98 }
99
100 func roundTrip(w http.ResponseWriter, req *http.Request) {
101         if req.Method == http.MethodHead {
102                 http.Error(w, "go away", http.StatusMethodNotAllowed)
103                 return
104         }
105         sinkReq <- fmt.Sprintf("%s %s", req.Method, req.URL.String())
106         host := strings.TrimSuffix(req.URL.Host, ":443")
107
108         for _, spy := range SpyDomains {
109                 if strings.HasSuffix(host, spy) {
110                         http.NotFound(w, req)
111                         sinkOther <- fmt.Sprintf(
112                                 "%s %s\t%d\tspy one",
113                                 req.Method,
114                                 req.URL.String(),
115                                 http.StatusNotFound,
116                         )
117                         return
118                 }
119         }
120
121         if strings.HasPrefix(req.URL.Host, "www.reddit.com") {
122                 req.URL.Host = "old.reddit.com"
123                 http.Redirect(w, req, req.URL.String(), http.StatusMovedPermanently)
124                 return
125         }
126
127         resp, err := transport.RoundTrip(req)
128         if err != nil {
129                 sinkErr <- fmt.Sprintf("%s\t%s", req.URL.Host, err.Error())
130                 w.WriteHeader(http.StatusBadGateway)
131                 w.Write([]byte(err.Error()))
132                 return
133         }
134
135         for k, vs := range resp.Header {
136                 if k == "Location" || k == "Content-Type" || k == "Content-Length" {
137                         continue
138                 }
139                 for _, v := range vs {
140                         w.Header().Add(k, v)
141                 }
142         }
143
144         switch resp.Header.Get("Content-Type") {
145         case "application/font-woff", "application/font-sfnt":
146                 // Those are deprecated types
147                 fallthrough
148         case "font/otf", "font/ttf", "font/woff", "font/woff2":
149                 http.NotFound(w, req)
150                 sinkOther <- fmt.Sprintf(
151                         "%s %s\t%d\tfonts are not allowed",
152                         req.Method,
153                         req.URL.String(),
154                         http.StatusNotFound,
155                 )
156                 resp.Body.Close()
157                 return
158         case "image/webp":
159                 if strings.Contains(req.Header.Get("User-Agent"), "AppleWebKit/538.15") {
160                         // My Xombrero
161                         break
162                 }
163                 tmpFd, err := ioutil.TempFile("", "tofuproxy.*.webp")
164                 if err != nil {
165                         log.Fatalln(err)
166                 }
167                 defer tmpFd.Close()
168                 defer os.Remove(tmpFd.Name())
169                 defer resp.Body.Close()
170                 if _, err = io.Copy(tmpFd, resp.Body); err != nil {
171                         log.Printf("Error during %s: %+v\n", req.URL, err)
172                         http.Error(w, err.Error(), http.StatusBadGateway)
173                         return
174                 }
175                 tmpFd.Close()
176                 cmd := exec.Command(CmdDWebP, tmpFd.Name(), "-o", "-")
177                 data, err := cmd.Output()
178                 if err != nil {
179                         http.Error(w, err.Error(), http.StatusBadGateway)
180                         return
181                 }
182                 w.Header().Add("Content-Type", "image/png")
183                 w.WriteHeader(http.StatusOK)
184                 w.Write(data)
185                 sinkOther <- fmt.Sprintf(
186                         "%s %s\t%d\tWebP transcoded to PNG",
187                         req.Method,
188                         req.URL.String(),
189                         http.StatusOK,
190                 )
191                 return
192         case "image/jxl":
193                 tmpFd, err := ioutil.TempFile("", "tofuproxy.*.jxl")
194                 if err != nil {
195                         log.Fatalln(err)
196                 }
197                 defer tmpFd.Close()
198                 defer os.Remove(tmpFd.Name())
199                 defer resp.Body.Close()
200                 if _, err = io.Copy(tmpFd, resp.Body); err != nil {
201                         log.Printf("Error during %s: %+v\n", req.URL, err)
202                         http.Error(w, err.Error(), http.StatusBadGateway)
203                         return
204                 }
205                 tmpFd.Close()
206                 dstFn := tmpFd.Name() + ".png"
207                 cmd := exec.Command(CmdDJXL, tmpFd.Name(), dstFn)
208                 err = cmd.Run()
209                 defer os.Remove(dstFn)
210                 if err != nil {
211                         http.Error(w, err.Error(), http.StatusBadGateway)
212                         return
213                 }
214                 data, err := ioutil.ReadFile(dstFn)
215                 if err != nil {
216                         http.Error(w, err.Error(), http.StatusBadGateway)
217                         return
218                 }
219                 w.Header().Add("Content-Type", "image/png")
220                 w.WriteHeader(http.StatusOK)
221                 w.Write(data)
222                 sinkOther <- fmt.Sprintf(
223                         "%s %s\t%d\tJPEG XL transcoded to PNG",
224                         req.Method,
225                         req.URL.String(),
226                         http.StatusOK,
227                 )
228                 return
229         }
230
231         if req.Method == http.MethodGet {
232                 var redirType string
233                 switch resp.StatusCode {
234                 case http.StatusMovedPermanently, http.StatusPermanentRedirect:
235                         redirType = "permanent"
236                         goto Redir
237                 case http.StatusFound, http.StatusSeeOther, http.StatusTemporaryRedirect:
238                         if strings.Contains(req.Header.Get("User-Agent"), "newsboat/") {
239                                 goto NoRedir
240                         }
241                         redirType = "temporary"
242                 default:
243                         goto NoRedir
244                 }
245         Redir:
246                 resp.Body.Close()
247                 w.Header().Add("Content-Type", "text/html")
248                 w.WriteHeader(http.StatusOK)
249                 location := resp.Header.Get("Location")
250                 w.Write([]byte(
251                         fmt.Sprintf(
252                                 `<html>
253 <head><title>%d %s: %s redirection</title></head>
254 <body>Redirection to <a href="%s">%s</a>
255 </body></html>`,
256                                 resp.StatusCode, http.StatusText(resp.StatusCode),
257                                 redirType, location, location,
258                         )))
259                 sinkRedir <- fmt.Sprintf(
260                         "%s %s\t%s\t%s", req.Method, resp.Status, req.URL.String(), location,
261                 )
262                 return
263         }
264
265 NoRedir:
266         for _, h := range []string{"Location", "Content-Type", "Content-Length"} {
267                 if v := resp.Header.Get(h); v != "" {
268                         w.Header().Add(h, v)
269                 }
270         }
271         w.WriteHeader(resp.StatusCode)
272         n, err := io.Copy(w, resp.Body)
273         if err != nil {
274                 log.Printf("Error during %s: %+v\n", req.URL, err)
275         }
276         resp.Body.Close()
277         msg := fmt.Sprintf(
278                 "%s %s\t%s\t%s\t%s",
279                 req.Method,
280                 req.URL.String(),
281                 resp.Status,
282                 resp.Header.Get("Content-Type"),
283                 humanize.IBytes(uint64(n)),
284         )
285         if resp.StatusCode == http.StatusOK {
286                 sinkOK <- msg
287         } else {
288                 sinkOther <- msg
289         }
290 }
291
292 type Handler struct{}
293
294 func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
295         if req.Method != http.MethodConnect {
296                 roundTrip(w, req)
297                 return
298         }
299         hj, ok := w.(http.Hijacker)
300         if !ok {
301                 log.Fatalln("no hijacking")
302         }
303         conn, _, err := hj.Hijack()
304         if err != nil {
305                 log.Fatalln(err)
306         }
307         defer conn.Close()
308         conn.Write([]byte(fmt.Sprintf(
309                 "%s %d %s\r\n\r\n",
310                 req.Proto,
311                 http.StatusOK, http.StatusText(http.StatusOK),
312         )))
313         host := strings.Split(req.Host, ":")[0]
314         hostCertsM.Lock()
315         keypair, ok := hostCerts[host]
316         if !ok || !keypair.cert.NotAfter.After(time.Now().Add(time.Hour)) {
317                 keypair = newKeypair(host, caCert, caPrv)
318                 hostCerts[host] = keypair
319         }
320         hostCertsM.Unlock()
321         tlsConn := tls.Server(conn, &tls.Config{
322                 Certificates: []tls.Certificate{{
323                         Certificate: [][]byte{keypair.cert.Raw},
324                         PrivateKey:  keypair.prv,
325                 }},
326         })
327         if err = tlsConn.Handshake(); err != nil {
328                 log.Printf("TLS error %s: %+v\n", host, err)
329                 return
330         }
331         srv := http.Server{
332                 Handler:      &HTTPSHandler{host: req.Host},
333                 TLSNextProto: tlsNextProtoS,
334         }
335         err = srv.Serve(&SingleListener{conn: tlsConn})
336         if err != nil {
337                 if _, ok := err.(AlreadyAccepted); !ok {
338                         log.Printf("TLS serve error %s: %+v\n", host, err)
339                         return
340                 }
341         }
342 }
343
344 type HTTPSHandler struct {
345         host string
346 }
347
348 func (h *HTTPSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
349         req.URL.Scheme = "https"
350         req.URL.Host = h.host
351         roundTrip(w, req)
352 }
353
354 func main() {
355         crtPath := flag.String("cert", "cert.pem", "Path to server X.509 certificate")
356         prvPath := flag.String("key", "prv.pem", "Path to server PKCS#8 private key")
357         bind := flag.String("bind", "[::1]:8080", "Bind address")
358         certs = flag.String("certs", "certs", "Directory with pinned certificates")
359         dnsSrv = flag.String("dns", "[::1]:53", "DNS server")
360         fifos = flag.String("fifos", "fifos", "Directory with FIFOs")
361         notai = flag.Bool("notai", false, "Do not prepend TAI64N to logs")
362         flag.Parse()
363         log.SetFlags(log.Lshortfile)
364         fifoInit()
365
366         var err error
367         _, caCert, err = ucspi.CertificateFromFile(*crtPath)
368         if err != nil {
369                 log.Fatalln(err)
370         }
371         caPrv, err = ucspi.PrivateKeyFromFile(*prvPath)
372         if err != nil {
373                 log.Fatalln(err)
374         }
375
376         ln, err := net.Listen("tcp", *bind)
377         if err != nil {
378                 log.Fatalln(err)
379         }
380         srv := http.Server{
381                 Handler:      &Handler{},
382                 TLSNextProto: tlsNextProtoS,
383         }
384         log.Println("listening:", *bind)
385         if err := srv.Serve(ln); err != nil {
386                 log.Fatalln(err)
387         }
388 }