]> Sergey Matveev's repositories - tofuproxy.git/blob - main.go
Initial commit
[tofuproxy.git] / main.go
1 /*
2 Copyright (C) 2021 Sergey Matveev <stargrave@stargrave.org>
3
4 This program is free software: you can redistribute it and/or modify
5 it under the terms of the GNU General Public License as published by
6 the Free Software Foundation, version 3 of the License.
7
8 This program is distributed in the hope that it will be useful,
9 but WITHOUT ANY WARRANTY; without even the implied warranty of
10 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 GNU General Public License for more details.
12
13 You should have received a copy of the GNU General Public License
14 along with this program.  If not, see <http://www.gnu.org/licenses/>.
15 */
16
17 package main
18
import (
	"bytes"
	"context"
	"crypto"
	"crypto/sha256"
	"crypto/tls"
	"crypto/x509"
	"encoding/hex"
	"flag"
	"fmt"
	"html"
	"io"
	"log"
	"net"
	"net/http"
	"os/exec"
	"strings"
	"sync"
	"time"

	"github.com/dustin/go-humanize"
	"go.cypherpunks.ru/ucspi"
)
41
42 var (
43         tlsNextProtoS = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
44         tlsNextProtoC = make(map[string]func(string, *tls.Conn) http.RoundTripper)
45         caCert        *x509.Certificate
46         caPrv         crypto.PrivateKey
47         certs         *string
48         dnsSrv        *string
49         transport     = http.Transport{
50                 ForceAttemptHTTP2:   false,
51                 DisableKeepAlives:   true,
52                 MaxIdleConnsPerHost: 2,
53                 TLSNextProto:        tlsNextProtoC,
54                 DialTLSContext:      dialTLS,
55         }
56
57         accepted  = make(map[string]string)
58         acceptedM sync.RWMutex
59         rejected  = make(map[string]string)
60         rejectedM sync.RWMutex
61 )
62
// spkiHash returns the lowercase hexadecimal SHA-256 digest of the
// certificate's raw (DER) SubjectPublicKeyInfo.
func spkiHash(cert *x509.Certificate) string {
	digest := sha256.New()
	digest.Write(cert.RawSubjectPublicKeyInfo) // hash.Hash writes never fail
	return hex.EncodeToString(digest.Sum(nil))
}
67
// certInfo pipes certRaw into "certtool --certificate-info --inder" and
// returns whatever the tool printed on stdout. If the command cannot be run
// or exits unsuccessfully, the error text is returned instead.
func certInfo(certRaw []byte) string {
	certtool := exec.Command("certtool", "--certificate-info", "--inder")
	certtool.Stdin = bytes.NewReader(certRaw)
	info, err := certtool.Output()
	if err != nil {
		return err.Error()
	}
	return string(info)
}
77
78 func acceptedAdd(addr, h string) {
79         acceptedM.Lock()
80         accepted[addr] = h
81         acceptedM.Unlock()
82 }
83
84 func rejectedAdd(addr, h string) {
85         rejectedM.Lock()
86         rejected[addr] = h
87         rejectedM.Unlock()
88 }
89
// ErrRejected is the error value identifying a rejected address; dialTLS
// checks for this type to decide whether a failed dial should be retried.
type ErrRejected struct {
	addr string // address reported in the error message
}

// Error implements the error interface.
func (e ErrRejected) Error() string {
	const suffix = " was rejected"
	return e.addr + suffix
}
95
96 func dialTLS(ctx context.Context, network, addr string) (net.Conn, error) {
97         host := strings.TrimSuffix(addr, ":443")
98         cfg := tls.Config{
99                 VerifyPeerCertificate: func(
100                         rawCerts [][]byte,
101                         verifiedChains [][]*x509.Certificate,
102                 ) error {
103                         return verifyCert(host, nil, rawCerts, verifiedChains)
104                 },
105         }
106         conn, dialErr := tls.Dial(network, addr, &cfg)
107         if dialErr != nil {
108                 if _, ok := dialErr.(ErrRejected); ok {
109                         return nil, dialErr
110                 }
111                 cfg.InsecureSkipVerify = true
112                 cfg.VerifyPeerCertificate = func(
113                         rawCerts [][]byte,
114                         verifiedChains [][]*x509.Certificate,
115                 ) error {
116                         return verifyCert(host, dialErr, rawCerts, verifiedChains)
117                 }
118                 var err error
119                 conn, err = tls.Dial(network, addr, &cfg)
120                 if err != nil {
121                         sinkErr <- fmt.Sprintf("%s\t%s", addr, dialErr.Error())
122                         return nil, err
123                 }
124         }
125         connState := conn.ConnectionState()
126         sinkTLS <- fmt.Sprintf(
127                 "%s\t%s %s\t%s",
128                 strings.TrimSuffix(addr, ":443"),
129                 ucspi.TLSVersion(connState.Version),
130                 tls.CipherSuiteName(connState.CipherSuite),
131                 spkiHash(connState.PeerCertificates[0]),
132         )
133         return conn, nil
134 }
135
136 func roundTrip(w http.ResponseWriter, req *http.Request) {
137         if req.Method == http.MethodHead {
138                 http.Error(w, "go away", http.StatusMethodNotAllowed)
139                 return
140         }
141         sinkReq <- fmt.Sprintf("%s %s", req.Method, req.URL.String())
142         host := strings.TrimSuffix(req.URL.Host, ":443")
143         for _, spy := range SpyDomains {
144                 if strings.HasSuffix(host, spy) {
145                         http.NotFound(w, req)
146                         sinkOther <- fmt.Sprintf(
147                                 "%s %s\t%d\tspy one",
148                                 req.Method,
149                                 req.URL.String(),
150                                 http.StatusNotFound,
151                         )
152                         return
153                 }
154         }
155         if strings.HasPrefix(req.URL.Host, "www.reddit.com") {
156                 req.URL.Host = "old.reddit.com"
157                 http.Redirect(w, req, req.URL.String(), http.StatusMovedPermanently)
158                 return
159         }
160         resp, err := transport.RoundTrip(req)
161         if err != nil {
162                 sinkErr <- fmt.Sprintf("%s\t%s", req.URL.Host, err.Error())
163                 w.WriteHeader(http.StatusBadGateway)
164                 w.Write([]byte(err.Error()))
165                 return
166         }
167         for k, vs := range resp.Header {
168                 if k == "Location" || k == "Content-Type" || k == "Content-Length" {
169                         continue
170                 }
171                 for _, v := range vs {
172                         w.Header().Add(k, v)
173                 }
174         }
175         if req.Method == http.MethodGet {
176                 var redirType string
177                 switch resp.StatusCode {
178                 case http.StatusMovedPermanently, http.StatusPermanentRedirect:
179                         redirType = "permanent"
180                         goto Redir
181                 case http.StatusFound, http.StatusSeeOther, http.StatusTemporaryRedirect:
182                         if strings.Contains(req.Header.Get("User-Agent"), "newsboat/") {
183                                 goto NoRedir
184                         }
185                         redirType = "temporary"
186                 default:
187                         goto NoRedir
188                 }
189         Redir:
190                 resp.Body.Close()
191                 w.Header().Add("Content-Type", "text/html")
192                 w.WriteHeader(http.StatusOK)
193                 location := resp.Header.Get("Location")
194                 w.Write([]byte(
195                         fmt.Sprintf(
196                                 `<html>
197 <head><title>%d %s: %s redirection</title></head>
198 <body>Redirection to <a href="%s">%s</a>
199 </body></html>`,
200                                 resp.StatusCode, http.StatusText(resp.StatusCode),
201                                 redirType, location, location,
202                         )))
203                 sinkRedir <- fmt.Sprintf(
204                         "%s %s\t%s\t%s", req.Method, resp.Status, req.URL.String(), location,
205                 )
206                 return
207         }
208 NoRedir:
209         for _, h := range []string{"Location", "Content-Type", "Content-Length"} {
210                 if v := resp.Header.Get(h); v != "" {
211                         w.Header().Add(h, v)
212                 }
213         }
214         w.WriteHeader(resp.StatusCode)
215         n, err := io.Copy(w, resp.Body)
216         if err != nil {
217                 log.Printf("Error during %s: %+v\n", req.URL, err)
218         }
219         resp.Body.Close()
220         msg := fmt.Sprintf(
221                 "%s %s\t%s\t%s\t%s",
222                 req.Method,
223                 req.URL.String(),
224                 resp.Status,
225                 resp.Header.Get("Content-Type"),
226                 humanize.IBytes(uint64(n)),
227         )
228         if resp.StatusCode == http.StatusOK {
229                 sinkOK <- msg
230         } else {
231                 sinkOther <- msg
232         }
233 }
234
235 type Handler struct{}
236
237 func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
238         if req.Method != http.MethodConnect {
239                 roundTrip(w, req)
240                 return
241         }
242         hj, ok := w.(http.Hijacker)
243         if !ok {
244                 log.Fatalln("no hijacking")
245         }
246         conn, _, err := hj.Hijack()
247         if err != nil {
248                 log.Fatalln(err)
249         }
250         defer conn.Close()
251         conn.Write([]byte(fmt.Sprintf(
252                 "%s %d %s\r\n\r\n",
253                 req.Proto,
254                 http.StatusOK, http.StatusText(http.StatusOK),
255         )))
256         host := strings.Split(req.Host, ":")[0]
257         hostCertsM.Lock()
258         keypair, ok := hostCerts[host]
259         if !ok || !keypair.cert.NotAfter.After(time.Now().Add(time.Hour)) {
260                 keypair = newKeypair(host, caCert, caPrv)
261                 hostCerts[host] = keypair
262         }
263         hostCertsM.Unlock()
264         tlsConn := tls.Server(conn, &tls.Config{
265                 Certificates: []tls.Certificate{{
266                         Certificate: [][]byte{keypair.cert.Raw},
267                         PrivateKey:  keypair.prv,
268                 }},
269         })
270         if err = tlsConn.Handshake(); err != nil {
271                 log.Printf("TLS error %s: %+v\n", host, err)
272                 return
273         }
274         srv := http.Server{
275                 Handler:      &HTTPSHandler{host: req.Host},
276                 TLSNextProto: tlsNextProtoS,
277         }
278         err = srv.Serve(&SingleListener{conn: tlsConn})
279         if err != nil {
280                 if _, ok := err.(AlreadyAccepted); !ok {
281                         log.Printf("TLS serve error %s: %+v\n", host, err)
282                         return
283                 }
284         }
285 }
286
287 type HTTPSHandler struct {
288         host string
289 }
290
291 func (h *HTTPSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
292         req.URL.Scheme = "https"
293         req.URL.Host = h.host
294         roundTrip(w, req)
295 }
296
297 func main() {
298         crtPath := flag.String("cert", "cert.pem", "Path to server X.509 certificate")
299         prvPath := flag.String("key", "prv.pem", "Path to server PKCS#8 private key")
300         bind := flag.String("bind", "[::1]:8080", "Bind address")
301         certs = flag.String("certs", "certs", "Directory with pinned certificates")
302         dnsSrv = flag.String("dns", "[::1]:53", "DNS server")
303         fifos = flag.String("fifos", "fifos", "Directory with FIFOs")
304         notai = flag.Bool("notai", false, "Do not prepend TAI64N to logs")
305         flag.Parse()
306         log.SetFlags(log.Lshortfile)
307         fifoInit()
308
309         var err error
310         _, caCert, err = ucspi.CertificateFromFile(*crtPath)
311         if err != nil {
312                 log.Fatalln(err)
313         }
314         caPrv, err = ucspi.PrivateKeyFromFile(*prvPath)
315         if err != nil {
316                 log.Fatalln(err)
317         }
318
319         ln, err := net.Listen("tcp", *bind)
320         if err != nil {
321                 log.Fatalln(err)
322         }
323         srv := http.Server{
324                 Handler:      &Handler{},
325                 TLSNextProto: tlsNextProtoS,
326         }
327         srv.SetKeepAlivesEnabled(false)
328         log.Println("listening:", *bind)
329         if err := srv.Serve(ln); err != nil {
330                 log.Fatalln(err)
331         }
332 }