]> Sergey Matveev's repositories - tofuproxy.git/blobdiff - main.go
Refactoring
[tofuproxy.git] / main.go
diff --git a/main.go b/main.go
deleted file mode 100644 (file)
index a48ec24..0000000
--- a/main.go
+++ /dev/null
@@ -1,412 +0,0 @@
-/*
-Copyright (C) 2021 Sergey Matveev <stargrave@stargrave.org>
-
-This program is free software: you can redistribute it and/or modify
-it under the terms of the GNU General Public License as published by
-the Free Software Foundation, version 3 of the License.
-
-This program is distributed in the hope that it will be useful,
-but WITHOUT ANY WARRANTY; without even the implied warranty of
-MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-GNU General Public License for more details.
-
-You should have received a copy of the GNU General Public License
-along with this program.  If not, see <http://www.gnu.org/licenses/>.
-*/
-
-package main
-
-import (
-       "context"
-       "crypto"
-       "crypto/tls"
-       "crypto/x509"
-       "flag"
-       "fmt"
-       "io"
-       "io/ioutil"
-       "log"
-       "net"
-       "net/http"
-       "os"
-       "os/exec"
-       "path/filepath"
-       "strings"
-       "time"
-
-       "github.com/dustin/go-humanize"
-       "go.cypherpunks.ru/ucspi"
-)
-
-var (
-       tlsNextProtoS = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
-       caCert        *x509.Certificate
-       caPrv         crypto.PrivateKey
-       transport     = http.Transport{
-               DialTLSContext:    dialTLS,
-               ForceAttemptHTTP2: true,
-       }
-       sessionCache = tls.NewLRUClientSessionCache(1024)
-
-       CmdDWebP = "dwebp"
-       CmdDJXL  = "djxl"
-
-       imageExts = map[string]struct{}{
-               ".apng": {},
-               ".avif": {},
-               ".gif":  {},
-               ".heic": {},
-               ".jp2":  {},
-               ".jpeg": {},
-               ".jpg":  {},
-               ".jxl":  {},
-               ".mng":  {},
-               ".png":  {},
-               ".svg":  {},
-               ".tif":  {},
-               ".tiff": {},
-               ".webp": {},
-       }
-)
-
-func dialTLS(ctx context.Context, network, addr string) (net.Conn, error) {
-       host := strings.TrimSuffix(addr, ":443")
-       cfg := tls.Config{
-               VerifyPeerCertificate: func(
-                       rawCerts [][]byte,
-                       verifiedChains [][]*x509.Certificate,
-               ) error {
-                       return verifyCert(host, nil, rawCerts, verifiedChains)
-               },
-               ClientSessionCache: sessionCache,
-               NextProtos:         []string{"h2", "http/1.1"},
-       }
-       conn, dialErr := tls.Dial(network, addr, &cfg)
-       if dialErr != nil {
-               if _, ok := dialErr.(ErrRejected); ok {
-                       return nil, dialErr
-               }
-               cfg.InsecureSkipVerify = true
-               cfg.VerifyPeerCertificate = func(
-                       rawCerts [][]byte,
-                       verifiedChains [][]*x509.Certificate,
-               ) error {
-                       return verifyCert(host, dialErr, rawCerts, verifiedChains)
-               }
-               var err error
-               conn, err = tls.Dial(network, addr, &cfg)
-               if err != nil {
-                       sinkErr <- fmt.Sprintf("%s\t%s", addr, dialErr.Error())
-                       return nil, err
-               }
-       }
-       connState := conn.ConnectionState()
-       if connState.DidResume {
-               sinkTLS <- fmt.Sprintf(
-                       "%s\t%s %s\t%s\t%s",
-                       strings.TrimSuffix(addr, ":443"),
-                       ucspi.TLSVersion(connState.Version),
-                       tls.CipherSuiteName(connState.CipherSuite),
-                       spkiHash(connState.PeerCertificates[0]),
-                       connState.NegotiatedProtocol,
-               )
-       }
-       return conn, nil
-}
-
-func roundTrip(w http.ResponseWriter, req *http.Request) {
-       if req.Method == http.MethodHead {
-               http.Error(w, "go away", http.StatusMethodNotAllowed)
-               return
-       }
-       sinkReq <- fmt.Sprintf("%s %s", req.Method, req.URL.String())
-       host := strings.TrimSuffix(req.URL.Host, ":443")
-
-       for _, spy := range SpyDomains {
-               if strings.HasSuffix(host, spy) {
-                       http.NotFound(w, req)
-                       sinkOther <- fmt.Sprintf(
-                               "%s %s\t%d\tspy one",
-                               req.Method,
-                               req.URL.String(),
-                               http.StatusNotFound,
-                       )
-                       return
-               }
-       }
-
-       if host == "www.reddit.com" {
-               req.URL.Host = "old.reddit.com"
-               http.Redirect(w, req, req.URL.String(), http.StatusMovedPermanently)
-               return
-       }
-
-       if host == "habrastorage.org" && strings.Contains(req.URL.Path, "r/w780q1") {
-               req.URL.Path = strings.Replace(req.URL.Path, "r/w780q1/", "", 1)
-               http.Redirect(w, req, req.URL.String(), http.StatusFound)
-               return
-       }
-
-       resp, err := transport.RoundTrip(req)
-       if err != nil {
-               sinkErr <- fmt.Sprintf("%s\t%s", req.URL.Host, err.Error())
-               w.WriteHeader(http.StatusBadGateway)
-               w.Write([]byte(err.Error()))
-               return
-       }
-
-       for k, vs := range resp.Header {
-               if k == "Location" || k == "Content-Type" || k == "Content-Length" {
-                       continue
-               }
-               for _, v := range vs {
-                       w.Header().Add(k, v)
-               }
-       }
-
-       switch resp.Header.Get("Content-Type") {
-       case "application/font-woff", "application/font-sfnt":
-               // Those are deprecated types
-               fallthrough
-       case "font/otf", "font/ttf", "font/woff", "font/woff2":
-               http.NotFound(w, req)
-               sinkOther <- fmt.Sprintf(
-                       "%s %s\t%d\tfonts are not allowed",
-                       req.Method,
-                       req.URL.String(),
-                       http.StatusNotFound,
-               )
-               resp.Body.Close()
-               return
-       case "image/webp":
-               if strings.Contains(req.Header.Get("User-Agent"), "AppleWebKit/538.15") {
-                       // My Xombrero
-                       break
-               }
-               tmpFd, err := ioutil.TempFile("", "tofuproxy.*.webp")
-               if err != nil {
-                       log.Fatalln(err)
-               }
-               defer tmpFd.Close()
-               defer os.Remove(tmpFd.Name())
-               defer resp.Body.Close()
-               if _, err = io.Copy(tmpFd, resp.Body); err != nil {
-                       log.Printf("Error during %s: %+v\n", req.URL, err)
-                       http.Error(w, err.Error(), http.StatusBadGateway)
-                       return
-               }
-               tmpFd.Close()
-               cmd := exec.Command(CmdDWebP, tmpFd.Name(), "-o", "-")
-               data, err := cmd.Output()
-               if err != nil {
-                       http.Error(w, err.Error(), http.StatusBadGateway)
-                       return
-               }
-               w.Header().Add("Content-Type", "image/png")
-               w.WriteHeader(http.StatusOK)
-               w.Write(data)
-               sinkOther <- fmt.Sprintf(
-                       "%s %s\t%d\tWebP transcoded to PNG",
-                       req.Method,
-                       req.URL.String(),
-                       http.StatusOK,
-               )
-               return
-       case "image/jxl":
-               tmpFd, err := ioutil.TempFile("", "tofuproxy.*.jxl")
-               if err != nil {
-                       log.Fatalln(err)
-               }
-               defer tmpFd.Close()
-               defer os.Remove(tmpFd.Name())
-               defer resp.Body.Close()
-               if _, err = io.Copy(tmpFd, resp.Body); err != nil {
-                       log.Printf("Error during %s: %+v\n", req.URL, err)
-                       http.Error(w, err.Error(), http.StatusBadGateway)
-                       return
-               }
-               tmpFd.Close()
-               dstFn := tmpFd.Name() + ".png"
-               cmd := exec.Command(CmdDJXL, tmpFd.Name(), dstFn)
-               err = cmd.Run()
-               defer os.Remove(dstFn)
-               if err != nil {
-                       http.Error(w, err.Error(), http.StatusBadGateway)
-                       return
-               }
-               data, err := ioutil.ReadFile(dstFn)
-               if err != nil {
-                       http.Error(w, err.Error(), http.StatusBadGateway)
-                       return
-               }
-               w.Header().Add("Content-Type", "image/png")
-               w.WriteHeader(http.StatusOK)
-               w.Write(data)
-               sinkOther <- fmt.Sprintf(
-                       "%s %s\t%d\tJPEG XL transcoded to PNG",
-                       req.Method,
-                       req.URL.String(),
-                       http.StatusOK,
-               )
-               return
-       }
-
-       if req.Method == http.MethodGet {
-               var redirType string
-               switch resp.StatusCode {
-               case http.StatusMovedPermanently, http.StatusPermanentRedirect:
-                       redirType = "permanent"
-                       goto Redir
-               case http.StatusFound, http.StatusSeeOther, http.StatusTemporaryRedirect:
-                       if strings.Contains(req.Header.Get("User-Agent"), "newsboat/") {
-                               goto NoRedir
-                       }
-                       if _, ok := imageExts[filepath.Ext(req.URL.Path)]; ok {
-                               goto NoRedir
-                       }
-                       redirType = "temporary"
-               default:
-                       goto NoRedir
-               }
-       Redir:
-               resp.Body.Close()
-               w.Header().Add("Content-Type", "text/html")
-               w.WriteHeader(http.StatusOK)
-               location := resp.Header.Get("Location")
-               w.Write([]byte(
-                       fmt.Sprintf(
-                               `<html><head><title>%d %s: %s redirection</title></head>
-<body>Redirection to <a href="%s">%s</a></body></html>`,
-                               resp.StatusCode, http.StatusText(resp.StatusCode),
-                               redirType, location, location,
-                       )))
-               sinkRedir <- fmt.Sprintf(
-                       "%s %s\t%s\t%s", req.Method, resp.Status, req.URL.String(), location,
-               )
-               return
-       }
-
-NoRedir:
-       for _, h := range []string{"Location", "Content-Type", "Content-Length"} {
-               if v := resp.Header.Get(h); v != "" {
-                       w.Header().Add(h, v)
-               }
-       }
-       w.WriteHeader(resp.StatusCode)
-       n, err := io.Copy(w, resp.Body)
-       if err != nil {
-               log.Printf("Error during %s: %+v\n", req.URL, err)
-       }
-       resp.Body.Close()
-       msg := fmt.Sprintf(
-               "%s %s\t%s\t%s\t%s",
-               req.Method,
-               req.URL.String(),
-               resp.Status,
-               resp.Header.Get("Content-Type"),
-               humanize.IBytes(uint64(n)),
-       )
-       if resp.StatusCode == http.StatusOK {
-               sinkOK <- msg
-       } else {
-               sinkOther <- msg
-       }
-}
-
-type Handler struct{}
-
-func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
-       if req.Method != http.MethodConnect {
-               roundTrip(w, req)
-               return
-       }
-       hj, ok := w.(http.Hijacker)
-       if !ok {
-               log.Fatalln("no hijacking")
-       }
-       conn, _, err := hj.Hijack()
-       if err != nil {
-               log.Fatalln(err)
-       }
-       defer conn.Close()
-       conn.Write([]byte(fmt.Sprintf(
-               "%s %d %s\r\n\r\n",
-               req.Proto,
-               http.StatusOK, http.StatusText(http.StatusOK),
-       )))
-       host := strings.Split(req.Host, ":")[0]
-       hostCertsM.Lock()
-       keypair, ok := hostCerts[host]
-       if !ok || !keypair.cert.NotAfter.After(time.Now().Add(time.Hour)) {
-               keypair = newKeypair(host, caCert, caPrv)
-               hostCerts[host] = keypair
-       }
-       hostCertsM.Unlock()
-       tlsConn := tls.Server(conn, &tls.Config{
-               Certificates: []tls.Certificate{{
-                       Certificate: [][]byte{keypair.cert.Raw},
-                       PrivateKey:  keypair.prv,
-               }},
-       })
-       if err = tlsConn.Handshake(); err != nil {
-               log.Printf("TLS error %s: %+v\n", host, err)
-               return
-       }
-       srv := http.Server{
-               Handler:      &HTTPSHandler{host: req.Host},
-               TLSNextProto: tlsNextProtoS,
-       }
-       err = srv.Serve(&SingleListener{conn: tlsConn})
-       if err != nil {
-               if _, ok := err.(AlreadyAccepted); !ok {
-                       log.Printf("TLS serve error %s: %+v\n", host, err)
-                       return
-               }
-       }
-}
-
-type HTTPSHandler struct {
-       host string
-}
-
-func (h *HTTPSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
-       req.URL.Scheme = "https"
-       req.URL.Host = h.host
-       roundTrip(w, req)
-}
-
-func main() {
-       crtPath := flag.String("cert", "cert.pem", "Path to server X.509 certificate")
-       prvPath := flag.String("key", "prv.pem", "Path to server PKCS#8 private key")
-       bind := flag.String("bind", "[::1]:8080", "Bind address")
-       certs = flag.String("certs", "certs", "Directory with pinned certificates")
-       dnsSrv = flag.String("dns", "[::1]:53", "DNS server")
-       fifos = flag.String("fifos", "fifos", "Directory with FIFOs")
-       notai = flag.Bool("notai", false, "Do not prepend TAI64N to logs")
-       flag.Parse()
-       log.SetFlags(log.Lshortfile)
-       fifoInit()
-
-       var err error
-       _, caCert, err = ucspi.CertificateFromFile(*crtPath)
-       if err != nil {
-               log.Fatalln(err)
-       }
-       caPrv, err = ucspi.PrivateKeyFromFile(*prvPath)
-       if err != nil {
-               log.Fatalln(err)
-       }
-
-       ln, err := net.Listen("tcp", *bind)
-       if err != nil {
-               log.Fatalln(err)
-       }
-       srv := http.Server{
-               Handler:      &Handler{},
-               TLSNextProto: tlsNextProtoS,
-       }
-       log.Println("listening:", *bind)
-       if err := srv.Serve(ln); err != nil {
-               log.Fatalln(err)
-       }
-}