Sergey Matveev's repositories - tofuproxy.git/blob - main.go
Tiny refactor, no keep-alive restrictions
[tofuproxy.git] / main.go
1 /*
2 Copyright (C) 2021 Sergey Matveev <stargrave@stargrave.org>
3
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, version 3 of the License.
7
8 This program is distributed in the hope that it will be useful,
9 but WITHOUT ANY WARRANTY; without even the implied warranty of
10 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 GNU General Public License for more details.
12
13 You should have received a copy of the GNU General Public License
14 along with this program.  If not, see <http://www.gnu.org/licenses/>.
15 */
16
17 package main
18
import (
	"context"
	"crypto"
	"crypto/tls"
	"crypto/x509"
	"flag"
	"fmt"
	"html"
	"io"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"os"
	"os/exec"
	"strings"
	"time"

	"github.com/dustin/go-humanize"
	"go.cypherpunks.ru/ucspi"
)
39
40 var (
41         tlsNextProtoS = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
42         caCert        *x509.Certificate
43         caPrv         crypto.PrivateKey
44         transport     = http.Transport{
45                 ForceAttemptHTTP2: false,
46                 TLSNextProto:      make(map[string]func(string, *tls.Conn) http.RoundTripper),
47                 DialTLSContext:    dialTLS,
48         }
49
50         CmdDWebP = "dwebp"
51         CmdDJXL  = "djxl"
52 )
53
54 func dialTLS(ctx context.Context, network, addr string) (net.Conn, error) {
55         host := strings.TrimSuffix(addr, ":443")
56         cfg := tls.Config{
57                 VerifyPeerCertificate: func(
58                         rawCerts [][]byte,
59                         verifiedChains [][]*x509.Certificate,
60                 ) error {
61                         return verifyCert(host, nil, rawCerts, verifiedChains)
62                 },
63         }
64         conn, dialErr := tls.Dial(network, addr, &cfg)
65         if dialErr != nil {
66                 if _, ok := dialErr.(ErrRejected); ok {
67                         return nil, dialErr
68                 }
69                 cfg.InsecureSkipVerify = true
70                 cfg.VerifyPeerCertificate = func(
71                         rawCerts [][]byte,
72                         verifiedChains [][]*x509.Certificate,
73                 ) error {
74                         return verifyCert(host, dialErr, rawCerts, verifiedChains)
75                 }
76                 var err error
77                 conn, err = tls.Dial(network, addr, &cfg)
78                 if err != nil {
79                         sinkErr <- fmt.Sprintf("%s\t%s", addr, dialErr.Error())
80                         return nil, err
81                 }
82         }
83         connState := conn.ConnectionState()
84         sinkTLS <- fmt.Sprintf(
85                 "%s\t%s %s\t%s",
86                 strings.TrimSuffix(addr, ":443"),
87                 ucspi.TLSVersion(connState.Version),
88                 tls.CipherSuiteName(connState.CipherSuite),
89                 spkiHash(connState.PeerCertificates[0]),
90         )
91         return conn, nil
92 }
93
94 func roundTrip(w http.ResponseWriter, req *http.Request) {
95         if req.Method == http.MethodHead {
96                 http.Error(w, "go away", http.StatusMethodNotAllowed)
97                 return
98         }
99         sinkReq <- fmt.Sprintf("%s %s", req.Method, req.URL.String())
100         host := strings.TrimSuffix(req.URL.Host, ":443")
101
102         for _, spy := range SpyDomains {
103                 if strings.HasSuffix(host, spy) {
104                         http.NotFound(w, req)
105                         sinkOther <- fmt.Sprintf(
106                                 "%s %s\t%d\tspy one",
107                                 req.Method,
108                                 req.URL.String(),
109                                 http.StatusNotFound,
110                         )
111                         return
112                 }
113         }
114
115         if strings.HasPrefix(req.URL.Host, "www.reddit.com") {
116                 req.URL.Host = "old.reddit.com"
117                 http.Redirect(w, req, req.URL.String(), http.StatusMovedPermanently)
118                 return
119         }
120
121         resp, err := transport.RoundTrip(req)
122         if err != nil {
123                 sinkErr <- fmt.Sprintf("%s\t%s", req.URL.Host, err.Error())
124                 w.WriteHeader(http.StatusBadGateway)
125                 w.Write([]byte(err.Error()))
126                 return
127         }
128
129         for k, vs := range resp.Header {
130                 if k == "Location" || k == "Content-Type" || k == "Content-Length" {
131                         continue
132                 }
133                 for _, v := range vs {
134                         w.Header().Add(k, v)
135                 }
136         }
137
138         switch resp.Header.Get("Content-Type") {
139         case "application/font-woff", "application/font-sfnt":
140                 // Those are deprecated types
141                 fallthrough
142         case "font/otf", "font/ttf", "font/woff", "font/woff2":
143                 http.NotFound(w, req)
144                 sinkOther <- fmt.Sprintf(
145                         "%s %s\t%d\tfonts are not allowed",
146                         req.Method,
147                         req.URL.String(),
148                         http.StatusNotFound,
149                 )
150                 resp.Body.Close()
151                 return
152         case "image/webp":
153                 if strings.Contains(req.Header.Get("User-Agent"), "AppleWebKit/538.15") {
154                         // My Xombrero
155                         break
156                 }
157                 tmpFd, err := ioutil.TempFile("", "tofuproxy.*.webp")
158                 if err != nil {
159                         log.Fatalln(err)
160                 }
161                 defer tmpFd.Close()
162                 defer os.Remove(tmpFd.Name())
163                 defer resp.Body.Close()
164                 if _, err = io.Copy(tmpFd, resp.Body); err != nil {
165                         log.Printf("Error during %s: %+v\n", req.URL, err)
166                         http.Error(w, err.Error(), http.StatusBadGateway)
167                         return
168                 }
169                 tmpFd.Close()
170                 cmd := exec.Command(CmdDWebP, tmpFd.Name(), "-o", "-")
171                 data, err := cmd.Output()
172                 if err != nil {
173                         http.Error(w, err.Error(), http.StatusBadGateway)
174                         return
175                 }
176                 w.Header().Add("Content-Type", "image/png")
177                 w.WriteHeader(http.StatusOK)
178                 w.Write(data)
179                 sinkOther <- fmt.Sprintf(
180                         "%s %s\t%d\tWebP transcoded to PNG",
181                         req.Method,
182                         req.URL.String(),
183                         http.StatusOK,
184                 )
185                 return
186         case "image/jxl":
187                 tmpFd, err := ioutil.TempFile("", "tofuproxy.*.jxl")
188                 if err != nil {
189                         log.Fatalln(err)
190                 }
191                 defer tmpFd.Close()
192                 defer os.Remove(tmpFd.Name())
193                 defer resp.Body.Close()
194                 if _, err = io.Copy(tmpFd, resp.Body); err != nil {
195                         log.Printf("Error during %s: %+v\n", req.URL, err)
196                         http.Error(w, err.Error(), http.StatusBadGateway)
197                         return
198                 }
199                 tmpFd.Close()
200                 dstFn := tmpFd.Name() + ".png"
201                 cmd := exec.Command(CmdDJXL, tmpFd.Name(), dstFn)
202                 err = cmd.Run()
203                 defer os.Remove(dstFn)
204                 if err != nil {
205                         http.Error(w, err.Error(), http.StatusBadGateway)
206                         return
207                 }
208                 data, err := ioutil.ReadFile(dstFn)
209                 if err != nil {
210                         http.Error(w, err.Error(), http.StatusBadGateway)
211                         return
212                 }
213                 w.Header().Add("Content-Type", "image/png")
214                 w.WriteHeader(http.StatusOK)
215                 w.Write(data)
216                 sinkOther <- fmt.Sprintf(
217                         "%s %s\t%d\tJPEG XL transcoded to PNG",
218                         req.Method,
219                         req.URL.String(),
220                         http.StatusOK,
221                 )
222                 return
223         }
224
225         if req.Method == http.MethodGet {
226                 var redirType string
227                 switch resp.StatusCode {
228                 case http.StatusMovedPermanently, http.StatusPermanentRedirect:
229                         redirType = "permanent"
230                         goto Redir
231                 case http.StatusFound, http.StatusSeeOther, http.StatusTemporaryRedirect:
232                         if strings.Contains(req.Header.Get("User-Agent"), "newsboat/") {
233                                 goto NoRedir
234                         }
235                         redirType = "temporary"
236                 default:
237                         goto NoRedir
238                 }
239         Redir:
240                 resp.Body.Close()
241                 w.Header().Add("Content-Type", "text/html")
242                 w.WriteHeader(http.StatusOK)
243                 location := resp.Header.Get("Location")
244                 w.Write([]byte(
245                         fmt.Sprintf(
246                                 `<html>
247 <head><title>%d %s: %s redirection</title></head>
248 <body>Redirection to <a href="%s">%s</a>
249 </body></html>`,
250                                 resp.StatusCode, http.StatusText(resp.StatusCode),
251                                 redirType, location, location,
252                         )))
253                 sinkRedir <- fmt.Sprintf(
254                         "%s %s\t%s\t%s", req.Method, resp.Status, req.URL.String(), location,
255                 )
256                 return
257         }
258
259 NoRedir:
260         for _, h := range []string{"Location", "Content-Type", "Content-Length"} {
261                 if v := resp.Header.Get(h); v != "" {
262                         w.Header().Add(h, v)
263                 }
264         }
265         w.WriteHeader(resp.StatusCode)
266         n, err := io.Copy(w, resp.Body)
267         if err != nil {
268                 log.Printf("Error during %s: %+v\n", req.URL, err)
269         }
270         resp.Body.Close()
271         msg := fmt.Sprintf(
272                 "%s %s\t%s\t%s\t%s",
273                 req.Method,
274                 req.URL.String(),
275                 resp.Status,
276                 resp.Header.Get("Content-Type"),
277                 humanize.IBytes(uint64(n)),
278         )
279         if resp.StatusCode == http.StatusOK {
280                 sinkOK <- msg
281         } else {
282                 sinkOther <- msg
283         }
284 }
285
286 type Handler struct{}
287
288 func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
289         if req.Method != http.MethodConnect {
290                 roundTrip(w, req)
291                 return
292         }
293         hj, ok := w.(http.Hijacker)
294         if !ok {
295                 log.Fatalln("no hijacking")
296         }
297         conn, _, err := hj.Hijack()
298         if err != nil {
299                 log.Fatalln(err)
300         }
301         defer conn.Close()
302         conn.Write([]byte(fmt.Sprintf(
303                 "%s %d %s\r\n\r\n",
304                 req.Proto,
305                 http.StatusOK, http.StatusText(http.StatusOK),
306         )))
307         host := strings.Split(req.Host, ":")[0]
308         hostCertsM.Lock()
309         keypair, ok := hostCerts[host]
310         if !ok || !keypair.cert.NotAfter.After(time.Now().Add(time.Hour)) {
311                 keypair = newKeypair(host, caCert, caPrv)
312                 hostCerts[host] = keypair
313         }
314         hostCertsM.Unlock()
315         tlsConn := tls.Server(conn, &tls.Config{
316                 Certificates: []tls.Certificate{{
317                         Certificate: [][]byte{keypair.cert.Raw},
318                         PrivateKey:  keypair.prv,
319                 }},
320         })
321         if err = tlsConn.Handshake(); err != nil {
322                 log.Printf("TLS error %s: %+v\n", host, err)
323                 return
324         }
325         srv := http.Server{
326                 Handler:      &HTTPSHandler{host: req.Host},
327                 TLSNextProto: tlsNextProtoS,
328         }
329         err = srv.Serve(&SingleListener{conn: tlsConn})
330         if err != nil {
331                 if _, ok := err.(AlreadyAccepted); !ok {
332                         log.Printf("TLS serve error %s: %+v\n", host, err)
333                         return
334                 }
335         }
336 }
337
338 type HTTPSHandler struct {
339         host string
340 }
341
342 func (h *HTTPSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
343         req.URL.Scheme = "https"
344         req.URL.Host = h.host
345         roundTrip(w, req)
346 }
347
348 func main() {
349         crtPath := flag.String("cert", "cert.pem", "Path to server X.509 certificate")
350         prvPath := flag.String("key", "prv.pem", "Path to server PKCS#8 private key")
351         bind := flag.String("bind", "[::1]:8080", "Bind address")
352         certs = flag.String("certs", "certs", "Directory with pinned certificates")
353         dnsSrv = flag.String("dns", "[::1]:53", "DNS server")
354         fifos = flag.String("fifos", "fifos", "Directory with FIFOs")
355         notai = flag.Bool("notai", false, "Do not prepend TAI64N to logs")
356         flag.Parse()
357         log.SetFlags(log.Lshortfile)
358         fifoInit()
359
360         var err error
361         _, caCert, err = ucspi.CertificateFromFile(*crtPath)
362         if err != nil {
363                 log.Fatalln(err)
364         }
365         caPrv, err = ucspi.PrivateKeyFromFile(*prvPath)
366         if err != nil {
367                 log.Fatalln(err)
368         }
369
370         ln, err := net.Listen("tcp", *bind)
371         if err != nil {
372                 log.Fatalln(err)
373         }
374         srv := http.Server{
375                 Handler:      &Handler{},
376                 TLSNextProto: tlsNextProtoS,
377         }
378         log.Println("listening:", *bind)
379         if err := srv.Serve(ln); err != nil {
380                 log.Fatalln(err)
381         }
382 }