]> Sergey Matveev's repositories - mmc.git/commitdiff
Verify SPKI hash
authorSergey Matveev <stargrave@stargrave.org>
Mon, 15 Apr 2024 13:47:36 +0000 (16:47 +0300)
committerSergey Matveev <stargrave@stargrave.org>
Mon, 15 Apr 2024 14:03:42 +0000 (17:03 +0300)
cmd/mmc/main.go
common.go

index 79a850ab48a4992096c347bbd54ab4903a0c5d69..069f0e41ef7b2b6ff5c65d47d4482e4c49b12a03 100644 (file)
@@ -19,6 +19,7 @@ package main
 import (
        "archive/tar"
        "bytes"
 import (
        "archive/tar"
        "bytes"
+       "crypto/tls"
        "encoding/json"
        "errors"
        "flag"
        "encoding/json"
        "errors"
        "flag"
@@ -26,6 +27,7 @@ import (
        "io"
        "io/fs"
        "log"
        "io"
        "io/fs"
        "log"
+       "net/http"
        "net/url"
        "os"
        "os/exec"
        "net/url"
        "os"
        "os/exec"
@@ -38,6 +40,7 @@ import (
        "time"
 
        "github.com/davecgh/go-spew/spew"
        "time"
 
        "github.com/davecgh/go-spew/spew"
+       "github.com/gorilla/websocket"
        "github.com/mattermost/mattermost-server/v6/model"
        "go.cypherpunks.ru/netrc"
        "go.stargrave.org/mmc"
        "github.com/mattermost/mattermost-server/v6/model"
        "go.cypherpunks.ru/netrc"
        "go.stargrave.org/mmc"
@@ -69,6 +72,7 @@ func mkFifo(pth string) {
 
 func main() {
        entrypoint := flag.String("entrypoint", mmc.GetEntrypoint(), "Entrypoint")
 
 func main() {
        entrypoint := flag.String("entrypoint", mmc.GetEntrypoint(), "Entrypoint")
+       spkiHash := flag.String("spki", mmc.GetSPKIHash(), "Entrypoint's SPKI hash")
        notifyCmd := flag.String("notify", "cmd/notify", "Path to notification handler")
        heartbeatCh := flag.String("heartbeat-ch", "town-square", "Channel for heartbeating")
        flag.Parse()
        notifyCmd := flag.String("notify", "cmd/notify", "Path to notification handler")
        heartbeatCh := flag.String("heartbeat-ch", "town-square", "Channel for heartbeating")
        flag.Parse()
@@ -96,6 +100,20 @@ func main() {
                log.Fatalln("no credentials found for:", entrypointURL.Hostname())
        }
        c := model.NewAPIv4Client(*entrypoint)
                log.Fatalln("no credentials found for:", entrypointURL.Hostname())
        }
        c := model.NewAPIv4Client(*entrypoint)
+       c.HTTPClient.Transport = &http.Transport{
+               Proxy:                 http.ProxyFromEnvironment,
+               ForceAttemptHTTP2:     true,
+               MaxIdleConns:          100,
+               IdleConnTimeout:       90 * time.Second,
+               TLSHandshakeTimeout:   10 * time.Second,
+               ExpectContinueTimeout: 1 * time.Second,
+               TLSClientConfig: &tls.Config{
+                       ServerName:            entrypointURL.Hostname(),
+                       InsecureSkipVerify:    true,
+                       VerifyPeerCertificate: mmc.NewVerifyPeerCertificate(*spkiHash),
+               },
+       }
+
        c.Login(login, password)
        me, resp, err := c.GetMe("")
        if err != nil {
        c.Login(login, password)
        me, resp, err := c.GetMe("")
        if err != nil {
@@ -452,7 +470,15 @@ func main() {
        default:
                log.Println("unhandled scheme:", entrypointURL.Scheme)
        }
        default:
                log.Println("unhandled scheme:", entrypointURL.Scheme)
        }
-       wc, err := model.NewWebSocketClient4(entrypointURL.String(), c.AuthToken)
+       wc, err := model.NewWebSocketClient4WithDialer(
+               &websocket.Dialer{
+                       TLSClientConfig: &tls.Config{
+                               ServerName:            entrypointURL.Hostname(),
+                               InsecureSkipVerify:    true,
+                               VerifyPeerCertificate: mmc.NewVerifyPeerCertificate(*spkiHash),
+                       },
+               }, entrypointURL.String(), c.AuthToken,
+       )
        if err != nil {
                log.Fatalln(err)
        }
        if err != nil {
                log.Fatalln(err)
        }
index f4d6a9d74b6da787780a6c75fafc0e091ce1e758..ce84af61c5cfb6fcafc871c25a499b8ad6a21826 100644 (file)
--- a/common.go
+++ b/common.go
 package mmc
 
 import (
 package mmc
 
 import (
+       "crypto/sha256"
+       "crypto/x509"
+       "encoding/hex"
+       "errors"
        "os"
        "strings"
        "time"
        "os"
        "strings"
        "time"
@@ -109,3 +113,30 @@ func GetEntrypoint() string {
        }
        return s
 }
        }
        return s
 }
+
+func GetSPKIHash() string {
+       s := os.Getenv("MMC_SPKI")
+       if s == "" {
+               return "deadbeef"
+       }
+       return s
+}
+
+func NewVerifyPeerCertificate(hashExpected string) func(
+       rawCerts [][]byte, verifiedChains [][]*x509.Certificate,
+) error {
+       return func(
+               rawCerts [][]byte, verifiedChains [][]*x509.Certificate,
+       ) error {
+               cer, err := x509.ParseCertificate(rawCerts[0])
+               if err != nil {
+                       return err
+               }
+               spki := cer.RawSubjectPublicKeyInfo
+               hsh := sha256.Sum256(spki)
+               if hashExpected != hex.EncodeToString(hsh[:]) {
+                       return errors.New("server certificate's SPKI hash mismatch")
+               }
+               return nil
+       }
+}