]> Sergey Matveev's repositories - tofuproxy.git/blob - main.go
16fae3d6c0742f5ab38723ce1dfa4ae6f704cf35
[tofuproxy.git] / main.go
1 /*
2 Copyright (C) 2021 Sergey Matveev <stargrave@stargrave.org>
3
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, version 3 of the License.
7
8 This program is distributed in the hope that it will be useful,
9 but WITHOUT ANY WARRANTY; without even the implied warranty of
10 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 GNU General Public License for more details.
12
13 You should have received a copy of the GNU General Public License
14 along with this program.  If not, see <http://www.gnu.org/licenses/>.
15 */
16
17 package main
18
import (
	"context"
	"crypto"
	"crypto/sha256"
	"crypto/tls"
	"crypto/x509"
	"encoding/hex"
	"flag"
	"fmt"
	"html"
	"io"
	"log"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/dustin/go-humanize"
	"go.cypherpunks.ru/ucspi"
)
39
40 var (
41         tlsNextProtoS = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
42         tlsNextProtoC = make(map[string]func(string, *tls.Conn) http.RoundTripper)
43         caCert        *x509.Certificate
44         caPrv         crypto.PrivateKey
45         certs         *string
46         dnsSrv        *string
47         transport     = http.Transport{
48                 ForceAttemptHTTP2:   false,
49                 DisableKeepAlives:   true,
50                 MaxIdleConnsPerHost: 2,
51                 TLSNextProto:        tlsNextProtoC,
52                 DialTLSContext:      dialTLS,
53         }
54
55         accepted  = make(map[string]string)
56         acceptedM sync.RWMutex
57         rejected  = make(map[string]string)
58         rejectedM sync.RWMutex
59 )
60
// spkiHash returns the lowercase hex encoding of the SHA-256 digest of the
// certificate's raw DER-encoded SubjectPublicKeyInfo.
func spkiHash(cert *x509.Certificate) string {
	spki := cert.RawSubjectPublicKeyInfo
	digest := sha256.Sum256(spki)
	encoded := hex.EncodeToString(digest[:])
	return encoded
}
65
66 func acceptedAdd(addr, h string) {
67         acceptedM.Lock()
68         accepted[addr] = h
69         acceptedM.Unlock()
70 }
71
72 func rejectedAdd(addr, h string) {
73         rejectedM.Lock()
74         rejected[addr] = h
75         rejectedM.Unlock()
76 }
77
// ErrRejected reports that addr was rejected. When dialTLS sees this error
// from its first connection attempt it gives up instead of retrying.
type ErrRejected struct {
	addr string // the rejected address, used verbatim in the message
}

// Error implements the error interface: "<addr> was rejected".
func (e ErrRejected) Error() string {
	const suffix = " was rejected"
	return e.addr + suffix
}
83
84 func dialTLS(ctx context.Context, network, addr string) (net.Conn, error) {
85         host := strings.TrimSuffix(addr, ":443")
86         cfg := tls.Config{
87                 VerifyPeerCertificate: func(
88                         rawCerts [][]byte,
89                         verifiedChains [][]*x509.Certificate,
90                 ) error {
91                         return verifyCert(host, nil, rawCerts, verifiedChains)
92                 },
93         }
94         conn, dialErr := tls.Dial(network, addr, &cfg)
95         if dialErr != nil {
96                 if _, ok := dialErr.(ErrRejected); ok {
97                         return nil, dialErr
98                 }
99                 cfg.InsecureSkipVerify = true
100                 cfg.VerifyPeerCertificate = func(
101                         rawCerts [][]byte,
102                         verifiedChains [][]*x509.Certificate,
103                 ) error {
104                         return verifyCert(host, dialErr, rawCerts, verifiedChains)
105                 }
106                 var err error
107                 conn, err = tls.Dial(network, addr, &cfg)
108                 if err != nil {
109                         sinkErr <- fmt.Sprintf("%s\t%s", addr, dialErr.Error())
110                         return nil, err
111                 }
112         }
113         connState := conn.ConnectionState()
114         sinkTLS <- fmt.Sprintf(
115                 "%s\t%s %s\t%s",
116                 strings.TrimSuffix(addr, ":443"),
117                 ucspi.TLSVersion(connState.Version),
118                 tls.CipherSuiteName(connState.CipherSuite),
119                 spkiHash(connState.PeerCertificates[0]),
120         )
121         return conn, nil
122 }
123
124 func roundTrip(w http.ResponseWriter, req *http.Request) {
125         if req.Method == http.MethodHead {
126                 http.Error(w, "go away", http.StatusMethodNotAllowed)
127                 return
128         }
129         sinkReq <- fmt.Sprintf("%s %s", req.Method, req.URL.String())
130         host := strings.TrimSuffix(req.URL.Host, ":443")
131         for _, spy := range SpyDomains {
132                 if strings.HasSuffix(host, spy) {
133                         http.NotFound(w, req)
134                         sinkOther <- fmt.Sprintf(
135                                 "%s %s\t%d\tspy one",
136                                 req.Method,
137                                 req.URL.String(),
138                                 http.StatusNotFound,
139                         )
140                         return
141                 }
142         }
143         if strings.HasPrefix(req.URL.Host, "www.reddit.com") {
144                 req.URL.Host = "old.reddit.com"
145                 http.Redirect(w, req, req.URL.String(), http.StatusMovedPermanently)
146                 return
147         }
148         resp, err := transport.RoundTrip(req)
149         if err != nil {
150                 sinkErr <- fmt.Sprintf("%s\t%s", req.URL.Host, err.Error())
151                 w.WriteHeader(http.StatusBadGateway)
152                 w.Write([]byte(err.Error()))
153                 return
154         }
155         contentType := resp.Header.Get("Content-Type")
156         switch contentType {
157         case "application/font-woff", "application/font-sfnt":
158                 // Those are deprecated types
159                 fallthrough
160         case "font/otf", "font/ttf", "font/woff", "font/woff2":
161                 http.NotFound(w, req)
162                 sinkOther <- fmt.Sprintf(
163                         "%s %s\t%d\tfonts are not allowed",
164                         req.Method,
165                         req.URL.String(),
166                         http.StatusNotFound,
167                 )
168                 resp.Body.Close()
169                 return
170         }
171         for k, vs := range resp.Header {
172                 if k == "Location" || k == "Content-Type" || k == "Content-Length" {
173                         continue
174                 }
175                 for _, v := range vs {
176                         w.Header().Add(k, v)
177                 }
178         }
179         if req.Method == http.MethodGet {
180                 var redirType string
181                 switch resp.StatusCode {
182                 case http.StatusMovedPermanently, http.StatusPermanentRedirect:
183                         redirType = "permanent"
184                         goto Redir
185                 case http.StatusFound, http.StatusSeeOther, http.StatusTemporaryRedirect:
186                         if strings.Contains(req.Header.Get("User-Agent"), "newsboat/") {
187                                 goto NoRedir
188                         }
189                         redirType = "temporary"
190                 default:
191                         goto NoRedir
192                 }
193         Redir:
194                 resp.Body.Close()
195                 w.Header().Add("Content-Type", "text/html")
196                 w.WriteHeader(http.StatusOK)
197                 location := resp.Header.Get("Location")
198                 w.Write([]byte(
199                         fmt.Sprintf(
200                                 `<html>
201 <head><title>%d %s: %s redirection</title></head>
202 <body>Redirection to <a href="%s">%s</a>
203 </body></html>`,
204                                 resp.StatusCode, http.StatusText(resp.StatusCode),
205                                 redirType, location, location,
206                         )))
207                 sinkRedir <- fmt.Sprintf(
208                         "%s %s\t%s\t%s", req.Method, resp.Status, req.URL.String(), location,
209                 )
210                 return
211         }
212 NoRedir:
213         for _, h := range []string{"Location", "Content-Type", "Content-Length"} {
214                 if v := resp.Header.Get(h); v != "" {
215                         w.Header().Add(h, v)
216                 }
217         }
218         w.WriteHeader(resp.StatusCode)
219         n, err := io.Copy(w, resp.Body)
220         if err != nil {
221                 log.Printf("Error during %s: %+v\n", req.URL, err)
222         }
223         resp.Body.Close()
224         msg := fmt.Sprintf(
225                 "%s %s\t%s\t%s\t%s",
226                 req.Method,
227                 req.URL.String(),
228                 resp.Status,
229                 resp.Header.Get("Content-Type"),
230                 humanize.IBytes(uint64(n)),
231         )
232         if resp.StatusCode == http.StatusOK {
233                 sinkOK <- msg
234         } else {
235                 sinkOther <- msg
236         }
237 }
238
239 type Handler struct{}
240
241 func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
242         if req.Method != http.MethodConnect {
243                 roundTrip(w, req)
244                 return
245         }
246         hj, ok := w.(http.Hijacker)
247         if !ok {
248                 log.Fatalln("no hijacking")
249         }
250         conn, _, err := hj.Hijack()
251         if err != nil {
252                 log.Fatalln(err)
253         }
254         defer conn.Close()
255         conn.Write([]byte(fmt.Sprintf(
256                 "%s %d %s\r\n\r\n",
257                 req.Proto,
258                 http.StatusOK, http.StatusText(http.StatusOK),
259         )))
260         host := strings.Split(req.Host, ":")[0]
261         hostCertsM.Lock()
262         keypair, ok := hostCerts[host]
263         if !ok || !keypair.cert.NotAfter.After(time.Now().Add(time.Hour)) {
264                 keypair = newKeypair(host, caCert, caPrv)
265                 hostCerts[host] = keypair
266         }
267         hostCertsM.Unlock()
268         tlsConn := tls.Server(conn, &tls.Config{
269                 Certificates: []tls.Certificate{{
270                         Certificate: [][]byte{keypair.cert.Raw},
271                         PrivateKey:  keypair.prv,
272                 }},
273         })
274         if err = tlsConn.Handshake(); err != nil {
275                 log.Printf("TLS error %s: %+v\n", host, err)
276                 return
277         }
278         srv := http.Server{
279                 Handler:      &HTTPSHandler{host: req.Host},
280                 TLSNextProto: tlsNextProtoS,
281         }
282         err = srv.Serve(&SingleListener{conn: tlsConn})
283         if err != nil {
284                 if _, ok := err.(AlreadyAccepted); !ok {
285                         log.Printf("TLS serve error %s: %+v\n", host, err)
286                         return
287                 }
288         }
289 }
290
291 type HTTPSHandler struct {
292         host string
293 }
294
295 func (h *HTTPSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
296         req.URL.Scheme = "https"
297         req.URL.Host = h.host
298         roundTrip(w, req)
299 }
300
301 func main() {
302         crtPath := flag.String("cert", "cert.pem", "Path to server X.509 certificate")
303         prvPath := flag.String("key", "prv.pem", "Path to server PKCS#8 private key")
304         bind := flag.String("bind", "[::1]:8080", "Bind address")
305         certs = flag.String("certs", "certs", "Directory with pinned certificates")
306         dnsSrv = flag.String("dns", "[::1]:53", "DNS server")
307         fifos = flag.String("fifos", "fifos", "Directory with FIFOs")
308         notai = flag.Bool("notai", false, "Do not prepend TAI64N to logs")
309         flag.Parse()
310         log.SetFlags(log.Lshortfile)
311         fifoInit()
312
313         var err error
314         _, caCert, err = ucspi.CertificateFromFile(*crtPath)
315         if err != nil {
316                 log.Fatalln(err)
317         }
318         caPrv, err = ucspi.PrivateKeyFromFile(*prvPath)
319         if err != nil {
320                 log.Fatalln(err)
321         }
322
323         ln, err := net.Listen("tcp", *bind)
324         if err != nil {
325                 log.Fatalln(err)
326         }
327         srv := http.Server{
328                 Handler:      &Handler{},
329                 TLSNextProto: tlsNextProtoS,
330         }
331         srv.SetKeepAlivesEnabled(false)
332         log.Println("listening:", *bind)
333         if err := srv.Serve(ln); err != nil {
334                 log.Fatalln(err)
335         }
336 }