]> Sergey Matveev's repositories - pgpmaildecryptor.git/commitdiff
Better headers keeping intact
authorSergey Matveev <stargrave@stargrave.org>
Fri, 10 Jan 2025 10:47:31 +0000 (13:47 +0300)
committerSergey Matveev <stargrave@stargrave.org>
Fri, 10 Jan 2025 10:47:31 +0000 (13:47 +0300)
README [new file with mode: 0644]
main.go

diff --git a/README b/README
new file mode 100644 (file)
index 0000000..5e60056
--- /dev/null
+++ b/README
@@ -0,0 +1,9 @@
+pgpmaildecryptor -- PGP-encrypted mail decryptor
+That utility reads RFC822 message from stdin, checks if it is
+PGP-encrypted one, calls "gpg" to decrypt it and outputs decrypted
+message to the stdout.
+
+It considers possible protected headers, removing non-protected ones
+from the original message.
+
+Arguments passed to the utility will be passed to to invoked gpg.
diff --git a/main.go b/main.go
index 341c921b72563a10317a978315c0a10231f9b3a556cdd8a9e84bf4bb0379041d..f64180088dcd7862a904b965c9d7cc127397b5b37c5f24fff8041095d92094df 100644 (file)
--- a/main.go
+++ b/main.go
@@ -6,6 +6,7 @@ import (
        "io"
        "log"
        "mime"
+       "mime/multipart"
        "net/mail"
        "net/textproto"
        "os"
@@ -13,37 +14,85 @@ import (
        "strings"
 )
 
+const (
+       ContentType        = "Content-Type"
+       MultipartEncrypted = "multipart/encrypted"
+       PGPEncrypted       = "application/pgp-encrypted"
+       OctetStream        = "application/octet-stream"
+       StatusPrefix       = "[GNUPG:] "
+       DecryptionOK       = "DECRYPTION_OKAY"
+)
+
+func readHdrLines(data []byte) (hdrs []string, values [][]string) {
+       var line string
+       s := bufio.NewScanner(bytes.NewReader(data))
+       for s.Scan() {
+               line = s.Text()
+               if line == "" {
+                       break
+               }
+               if line[0] == ' ' || line[0] == '\t' {
+                       values[len(values)-1] = append(values[len(values)-1], line)
+                       continue
+               }
+               hdrs = append(hdrs, textproto.CanonicalMIMEHeaderKey(
+                       line[:strings.Index(line, ":")]))
+               values = append(values, []string{line})
+       }
+       return
+}
+
 func main() {
        log.SetFlags(log.Lshortfile)
-       rawEnc, err := io.ReadAll(bufio.NewReader(os.Stdin))
+       encRaw, err := io.ReadAll(bufio.NewReader(os.Stdin))
        if err != nil {
                log.Fatal(err)
        }
-       msgEnc, err := mail.ReadMessage(bytes.NewReader(rawEnc))
+       msgEnc, err := mail.ReadMessage(bytes.NewReader(encRaw))
        if err != nil {
                log.Fatal(err)
        }
+       var part *multipart.Part
        {
-               ct := msgEnc.Header.Get("Content-Type")
-               if ct == "" {
-                       log.Fatal("no Content-Type")
+               var boundary string
+               {
+                       ct := msgEnc.Header.Get(ContentType)
+                       if ct == "" {
+                               log.Fatal("no", ContentType)
+                       }
+                       mt, params, errParse := mime.ParseMediaType(ct)
+                       if errParse != nil {
+                               log.Fatal(errParse)
+                       }
+                       if mt != MultipartEncrypted {
+                               log.Fatal("not", MultipartEncrypted)
+                       }
+                       if proto := params["protocol"]; proto != PGPEncrypted {
+                               log.Fatal("not", PGPEncrypted)
+                       }
+                       boundary = params["boundary"]
                }
-               mt, params, errParse := mime.ParseMediaType(ct)
-               if errParse != nil {
-                       log.Fatal(errParse)
+               mpr := multipart.NewReader(msgEnc.Body, boundary)
+               part, err = mpr.NextPart()
+               if err != nil {
+                       log.Fatal(err)
                }
-               if mt != "multipart/encrypted" {
-                       log.Fatal("not multipart/encrypted")
+               if part.Header.Get(ContentType) != PGPEncrypted {
+                       log.Fatal("wrong first part's", ContentType)
                }
-               if proto := params["protocol"]; proto != "application/pgp-encrypted" {
-                       log.Fatal("not application/pgp-encrypted")
+               part, err = mpr.NextPart()
+               if err != nil {
+                       log.Fatal(err)
+               }
+               if part.Header.Get(ContentType) != OctetStream {
+                       log.Fatal("wrong second part's", ContentType)
                }
        }
-       var rawDec []byte
+       var line string
+       var decRaw []byte
        {
                args := []string{"--batch", "--decrypt", "--status-fd", "3"}
-               args = append(args, os.Args[1:]...)
-               cmd := exec.Command("gpg", args...)
+               cmd := exec.Command("gpg", append(args, os.Args[1:]...)...)
                cmd.Stderr = os.Stderr
                var stdin io.WriteCloser
                stdin, err = cmd.StdinPipe()
@@ -61,13 +110,12 @@ func main() {
                        log.Fatal(err)
                }
                cmd.ExtraFiles = append(cmd.ExtraFiles, statusW)
-
                err = cmd.Start()
                if err != nil {
                        log.Fatal(err)
                }
                go func() {
-                       _, errStdin := io.Copy(stdin, msgEnc.Body)
+                       _, errStdin := io.Copy(stdin, part)
                        if err != nil {
                                log.Print(errStdin)
                        }
@@ -78,19 +126,18 @@ func main() {
                var goodDec bool
                go func() {
                        s := bufio.NewScanner(statusR)
-                       var t string
                        for s.Scan() {
-                               t = s.Text()
-                               if t == "[GNUPG:] DECRYPTION_OKAY" {
+                               line = s.Text()
+                               if line == StatusPrefix+DecryptionOK {
                                        goodDec = true
                                }
-                               log.Print(t)
+                               log.Print(line)
                        }
                        if s.Err() != nil {
                                log.Print(s.Err())
                        }
                }()
-               rawDec, err = io.ReadAll(stdout)
+               decRaw, err = io.ReadAll(stdout)
                if err != nil {
                        log.Fatal(err)
                }
@@ -99,37 +146,28 @@ func main() {
                        log.Fatal(err)
                }
                if !goodDec {
-                       log.Fatal("no DECRYPTION_OKAY received")
+                       log.Fatal("no", DecryptionOK, "received")
                }
        }
-       msgDec, err := mail.ReadMessage(bytes.NewReader(rawDec))
-       if err != nil {
-               log.Fatal(err)
-       }
        stdout := bufio.NewWriter(os.Stdout)
-       r := textproto.NewReader(bufio.NewReader(bytes.NewReader(rawEnc)))
-       var t string
-       var hdr string
-       for {
-               t, err = r.ReadContinuedLine()
-               if err != nil {
-                       log.Fatal(err)
-               }
-               if t == "" {
-                       break
-               }
-               hdr = t[:strings.Index(t, ":")]
-               if msgDec.Header.Get(hdr) != "" {
-                       continue
-               }
-               if _, err = stdout.WriteString(t); err != nil {
-                       log.Fatal(err)
-               }
-               if _, err = stdout.WriteString("\n"); err != nil {
-                       log.Fatal(err)
+       {
+               encHdrs, encValues := readHdrLines(encRaw)
+               decHdrs, _ := readHdrLines(decRaw)
+       EncHdr:
+               for idxEnc, encHdr := range encHdrs {
+                       for _, decHdr := range decHdrs {
+                               if decHdr == encHdr {
+                                       continue EncHdr
+                               }
+                       }
+                       for _, line := range encValues[idxEnc] {
+                               if _, err = stdout.WriteString(line + "\n"); err != nil {
+                                       log.Fatal(err)
+                               }
+                       }
                }
        }
-       if _, err = io.Copy(stdout, bytes.NewReader(rawDec)); err != nil {
+       if _, err = io.Copy(stdout, bytes.NewReader(decRaw)); err != nil {
                log.Fatal(err)
        }
        if err = stdout.Flush(); err != nil {