]> Sergey Matveev's repositories - tofuproxy.git/blobdiff - main.go
TLS session resumption support
[tofuproxy.git] / main.go
diff --git a/main.go b/main.go
index 6521cb618149a91448720480785f78b641a5e167..69448ce2a66f53860064e9bde6ba3c0b4ed0f73c 100644 (file)
--- a/main.go
+++ b/main.go
@@ -17,22 +17,20 @@ along with this program.  If not, see <http://www.gnu.org/licenses/>.
 package main
 
 import (
-       "bytes"
        "context"
        "crypto"
-       "crypto/sha256"
        "crypto/tls"
        "crypto/x509"
-       "encoding/hex"
        "flag"
        "fmt"
        "io"
+       "io/ioutil"
        "log"
        "net"
        "net/http"
+       "os"
        "os/exec"
        "strings"
-       "sync"
        "time"
 
        "github.com/dustin/go-humanize"
@@ -41,58 +39,19 @@ import (
 
 var (
        tlsNextProtoS = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
-       tlsNextProtoC = make(map[string]func(string, *tls.Conn) http.RoundTripper)
        caCert        *x509.Certificate
        caPrv         crypto.PrivateKey
-       certs         *string
-       dnsSrv        *string
        transport     = http.Transport{
-               ForceAttemptHTTP2:   false,
-               DisableKeepAlives:   true,
-               MaxIdleConnsPerHost: 2,
-               TLSNextProto:        tlsNextProtoC,
-               DialTLSContext:      dialTLS,
+               ForceAttemptHTTP2: false,
+               TLSNextProto:      make(map[string]func(string, *tls.Conn) http.RoundTripper),
+               DialTLSContext:    dialTLS,
        }
+       sessionCache = tls.NewLRUClientSessionCache(1024)
 
-       accepted  = make(map[string]string)
-       acceptedM sync.RWMutex
-       rejected  = make(map[string]string)
-       rejectedM sync.RWMutex
+       CmdDWebP = "dwebp"
+       CmdDJXL  = "djxl"
 )
 
-func spkiHash(cert *x509.Certificate) string {
-       hsh := sha256.Sum256(cert.RawSubjectPublicKeyInfo)
-       return hex.EncodeToString(hsh[:])
-}
-
-func certInfo(certRaw []byte) string {
-       cmd := exec.Command("certtool", "--certificate-info", "--inder")
-       cmd.Stdin = bytes.NewReader(certRaw)
-       out, err := cmd.Output()
-       if err == nil {
-               return string(out)
-       }
-       return err.Error()
-}
-
-func acceptedAdd(addr, h string) {
-       acceptedM.Lock()
-       accepted[addr] = h
-       acceptedM.Unlock()
-}
-
-func rejectedAdd(addr, h string) {
-       rejectedM.Lock()
-       rejected[addr] = h
-       rejectedM.Unlock()
-}
-
-type ErrRejected struct {
-       addr string
-}
-
-func (err ErrRejected) Error() string { return err.addr + " was rejected" }
-
 func dialTLS(ctx context.Context, network, addr string) (net.Conn, error) {
        host := strings.TrimSuffix(addr, ":443")
        cfg := tls.Config{
@@ -102,6 +61,7 @@ func dialTLS(ctx context.Context, network, addr string) (net.Conn, error) {
                ) error {
                        return verifyCert(host, nil, rawCerts, verifiedChains)
                },
+               ClientSessionCache: sessionCache,
        }
        conn, dialErr := tls.Dial(network, addr, &cfg)
        if dialErr != nil {
@@ -123,13 +83,17 @@ func dialTLS(ctx context.Context, network, addr string) (net.Conn, error) {
                }
        }
        connState := conn.ConnectionState()
-       sinkTLS <- fmt.Sprintf(
+       msg := fmt.Sprintf(
                "%s\t%s %s\t%s",
                strings.TrimSuffix(addr, ":443"),
                ucspi.TLSVersion(connState.Version),
                tls.CipherSuiteName(connState.CipherSuite),
                spkiHash(connState.PeerCertificates[0]),
        )
+       if connState.DidResume {
+               msg += "\tresumed"
+       }
+       sinkTLS <- msg
        return conn, nil
 }
 
@@ -140,6 +104,7 @@ func roundTrip(w http.ResponseWriter, req *http.Request) {
        }
        sinkReq <- fmt.Sprintf("%s %s", req.Method, req.URL.String())
        host := strings.TrimSuffix(req.URL.Host, ":443")
+
        for _, spy := range SpyDomains {
                if strings.HasSuffix(host, spy) {
                        http.NotFound(w, req)
@@ -152,11 +117,13 @@ func roundTrip(w http.ResponseWriter, req *http.Request) {
                        return
                }
        }
+
        if strings.HasPrefix(req.URL.Host, "www.reddit.com") {
                req.URL.Host = "old.reddit.com"
                http.Redirect(w, req, req.URL.String(), http.StatusMovedPermanently)
                return
        }
+
        resp, err := transport.RoundTrip(req)
        if err != nil {
                sinkErr <- fmt.Sprintf("%s\t%s", req.URL.Host, err.Error())
@@ -164,6 +131,7 @@ func roundTrip(w http.ResponseWriter, req *http.Request) {
                w.Write([]byte(err.Error()))
                return
        }
+
        for k, vs := range resp.Header {
                if k == "Location" || k == "Content-Type" || k == "Content-Length" {
                        continue
@@ -172,6 +140,94 @@ func roundTrip(w http.ResponseWriter, req *http.Request) {
                        w.Header().Add(k, v)
                }
        }
+
+       switch resp.Header.Get("Content-Type") {
+       case "application/font-woff", "application/font-sfnt":
+               // Those are deprecated types
+               fallthrough
+       case "font/otf", "font/ttf", "font/woff", "font/woff2":
+               http.NotFound(w, req)
+               sinkOther <- fmt.Sprintf(
+                       "%s %s\t%d\tfonts are not allowed",
+                       req.Method,
+                       req.URL.String(),
+                       http.StatusNotFound,
+               )
+               resp.Body.Close()
+               return
+       case "image/webp":
+               if strings.Contains(req.Header.Get("User-Agent"), "AppleWebKit/538.15") {
+                       // My Xombrero
+                       break
+               }
+               tmpFd, err := ioutil.TempFile("", "tofuproxy.*.webp")
+               if err != nil {
+                       log.Fatalln(err)
+               }
+               defer tmpFd.Close()
+               defer os.Remove(tmpFd.Name())
+               defer resp.Body.Close()
+               if _, err = io.Copy(tmpFd, resp.Body); err != nil {
+                       log.Printf("Error during %s: %+v\n", req.URL, err)
+                       http.Error(w, err.Error(), http.StatusBadGateway)
+                       return
+               }
+               tmpFd.Close()
+               cmd := exec.Command(CmdDWebP, tmpFd.Name(), "-o", "-")
+               data, err := cmd.Output()
+               if err != nil {
+                       http.Error(w, err.Error(), http.StatusBadGateway)
+                       return
+               }
+               w.Header().Add("Content-Type", "image/png")
+               w.WriteHeader(http.StatusOK)
+               w.Write(data)
+               sinkOther <- fmt.Sprintf(
+                       "%s %s\t%d\tWebP transcoded to PNG",
+                       req.Method,
+                       req.URL.String(),
+                       http.StatusOK,
+               )
+               return
+       case "image/jxl":
+               tmpFd, err := ioutil.TempFile("", "tofuproxy.*.jxl")
+               if err != nil {
+                       log.Fatalln(err)
+               }
+               defer tmpFd.Close()
+               defer os.Remove(tmpFd.Name())
+               defer resp.Body.Close()
+               if _, err = io.Copy(tmpFd, resp.Body); err != nil {
+                       log.Printf("Error during %s: %+v\n", req.URL, err)
+                       http.Error(w, err.Error(), http.StatusBadGateway)
+                       return
+               }
+               tmpFd.Close()
+               dstFn := tmpFd.Name() + ".png"
+               cmd := exec.Command(CmdDJXL, tmpFd.Name(), dstFn)
+               err = cmd.Run()
+               defer os.Remove(dstFn)
+               if err != nil {
+                       http.Error(w, err.Error(), http.StatusBadGateway)
+                       return
+               }
+               data, err := ioutil.ReadFile(dstFn)
+               if err != nil {
+                       http.Error(w, err.Error(), http.StatusBadGateway)
+                       return
+               }
+               w.Header().Add("Content-Type", "image/png")
+               w.WriteHeader(http.StatusOK)
+               w.Write(data)
+               sinkOther <- fmt.Sprintf(
+                       "%s %s\t%d\tJPEG XL transcoded to PNG",
+                       req.Method,
+                       req.URL.String(),
+                       http.StatusOK,
+               )
+               return
+       }
+
        if req.Method == http.MethodGet {
                var redirType string
                switch resp.StatusCode {
@@ -205,6 +261,7 @@ func roundTrip(w http.ResponseWriter, req *http.Request) {
                )
                return
        }
+
 NoRedir:
        for _, h := range []string{"Location", "Content-Type", "Content-Length"} {
                if v := resp.Header.Get(h); v != "" {
@@ -324,7 +381,6 @@ func main() {
                Handler:      &Handler{},
                TLSNextProto: tlsNextProtoS,
        }
-       srv.SetKeepAlivesEnabled(false)
        log.Println("listening:", *bind)
        if err := srv.Serve(ln); err != nil {
                log.Fatalln(err)