README | 6 +++--- doc/features.rst | 6 +++--- doc/news.rst | 1 + pyderasn.py | 300 +++++++++++++++++++++++++++++++++++++++++++++++++++-- tests/test_cms.py | 92 +++++++++++++++++++++++++++++++++++++++++++++-------- tests/test_crl.py | 19 ++++++++++++++----- tests/test_pyderasn.py | 94 +++++++++++++++++++++++++++++++++++++++++++++++++++++ diff --git a/README b/README index fba7f6049df8ef938de7d2fe5c4fbeb8765bb1133003dea3e73ac6666ef8fa77..9c397c227756a7398892f990e1b390828fbfe6a4271b2b1e6039522828feba3b 100644 --- a/README +++ b/README @@ -19,9 +19,9 @@ * Ability to know exact decoded objects offset and lengths in the binary * Ability to allow BER-encoded data with knowing if any of specified field has either DER or BER encoding (or possibly indefinite-length encoding) -* Ability to use mmap-ed files, memoryviews, iterators and CER encoder - dealing with the writer, giving ability to create huge ASN.1 encoded - files without storing all the data in the memory first +* Ability to use mmap-ed files, memoryviews, iterators, 2-pass DER + encoding mode and CER encoder dealing with the writer, giving ability + to create huge ASN.1 encoded files with very little memory footprint * Ability to decode files in event generation mode, without the need to keep all the data and decoded structures in the memory * __slots__, copy.copy() friendliness diff --git a/doc/features.rst b/doc/features.rst index e537f9e5d8a195bddb178a9f25346389d46db4cd7645f31d207eac5a45dcbf03..5299899da983f985e14235292ad9ee631438c451a2a36e8bf76b9e7a74ea04ed 100644 --- a/doc/features.rst +++ b/doc/features.rst @@ -38,9 +38,9 @@ `CMS `__ structures allow BER encoding for the whole message, except for ``SignedAttributes`` -- you can easily verify your CMS satisfies that requirement -* Ability to use mmap-ed files, memoryviews, iterators and CER encoder - dealing with the writer, giving ability to create huge ASN.1 encoded - files without storing all the data in the memory first +* Ability to use mmap-ed files, 
memoryviews, iterators, 2-pass DER + encoding mode and CER encoder dealing with the writer, giving ability + to create huge ASN.1 encoded files with very little memory footprint * Ability to decode files in event generation mode, without the need to keep all the data and decoded structures (that takes huge quantity of memory in all known ASN.1 libraries) in the memory diff --git a/doc/news.rst b/doc/news.rst index 751e8046ee81060eba318efd5bcede6e951a569f3a6ba2d07f5a97c9c64c15d5..f30ea7c54ba9eb5f2e310ac1170754f08d91fe7aeb64b419c2c3f1e513935996 100644 --- a/doc/news.rst +++ b/doc/news.rst @@ -7,6 +7,7 @@ 7.2 --- * Restored workability of some command line options +* 2-pass DER encoding mode with very little memory footprint .. _release7.1: diff --git a/pyderasn.py b/pyderasn.py index 262e48144addbb3ef72a8ff561ac3b05dc925ea1a40c33d2944c2531a8f6536e..7b32e2b823b1e2319de8d76f4b1146d1535cef0ed8ec01c42cfc3dc13946436c 100755 --- a/pyderasn.py +++ b/pyderasn.py @@ -829,6 +829,8 @@ copy the payload (without BER/CER encoding interleaved overhead) in it. Virtually it won't take memory more than for keeping small structures and 1 KB binary chunks. +.. _seqof-iterators: + SEQUENCE OF iterators _____________________ @@ -844,6 +846,54 @@ generator taking necessary data from the database and giving the ``RevokedCertificate`` objects. Only binary representation of that objects will take memory during DER encoding. +2-pass DER encoding +------------------- + +There is ability to do 2-pass encoding to DER, writing results directly +to specified writer (buffer, file, whatever). It could be 1.5+ times +slower than ordinary encoding, but it takes little memory for 1st pass +state storing. For example, 1st pass state for CACert.org's CRL with +~416K of certificate entries takes nearly 3.5 MB of memory. +``SignedData`` with several gigabyte ``EncapsulatedContentInfo`` takes +nearly 0.5 KB of memory. 
+ +If you use :ref:`mmap-ed ` memoryviews, :ref:`SEQUENCE OF +iterators <seqof-iterators>` and write directly to opened file, then +there is very small memory footprint. + +1st pass traverses through all the objects of the structure and returns +the size of DER encoded structure, together with 1st pass state object. +That state contains precalculated lengths for various objects inside the +structure. + +:: + + fulllen, state = obj.encode1st() + +2nd pass takes the writer and 1st pass state. It traverses through all +the objects again, but writes their encoded representation to the writer. + +:: + + opener = io.open if PY2 else open + with opener("result", "wb") as fd: + obj.encode2nd(fd.write, iter(state)) + +.. warning:: + + You **MUST NOT** use 1st pass state if anything is changed in the + objects. It is intended to be used immediately after 1st pass is + done! + +If you use :ref:`SEQUENCE OF iterators <seqof-iterators>`, then you +have to reinitialize the values after the 1st pass. And you **have to** +be sure that the iterator gives exactly the same values as previously. +Yes, you have to run your iterator twice -- because this is two pass +encoding mode. + +If you want to encode to the memory, then you can use convenient +:py:func:`pyderasn.encode2pass` helper. + Base Obj -------- .. autoclass:: pyderasn.Obj @@ -955,6 +1005,7 @@ .. autofunction:: pyderasn.abs_decode_path .. autofunction:: pyderasn.agg_octet_string .. autofunction:: pyderasn.colonize_hex +.. autofunction:: pyderasn.encode2pass .. autofunction:: pyderasn.encode_cer .. autofunction:: pyderasn.file_mmaped .. 
autofunction:: pyderasn.hexenc @@ -1101,6 +1152,7 @@ from mmap import PROT_READ from operator import attrgetter from string import ascii_letters from string import digits +from sys import maxsize as sys_maxsize from sys import version_info from unicodedata import category as unicat @@ -1138,6 +1190,7 @@ "BoundsError", "Choice", "DecodeError", "DecodePathDefBy", + "encode2pass", "encode_cer", "Enumerated", "ExceedingData", @@ -1523,6 +1576,28 @@ LEN0 = len_encode(0) LEN1 = len_encode(1) LEN1K = len_encode(1000) + + +def len_size(l): + """How many bytes length field will take + """ + if l < 128: + return 1 + if l < 256: # 1 << 8 + return 2 + if l < 65536: # 1 << 16 + return 3 + if l < 16777216: # 1 << 24 + return 4 + if l < 4294967296: # 1 << 32 + return 5 + if l < 1099511627776: # 1 << 40 + return 6 + if l < 281474976710656: # 1 << 48 + return 7 + if l < 72057594037927936: # 1 << 56 + return 8 + raise OverflowError("too big length") def write_full(writer, data): @@ -1543,6 +1618,17 @@ raise ValueError("can not write to buf") written += n +# If it is 64-bit system, then use compact 64-bit array of unsigned +# longs. Use an ordinary list with universal integers otherwise, that +# is slower. 
+if sys_maxsize > 2 ** 32: + def state_2pass_new(): + return array("L") +else: + def state_2pass_new(): + return [] + + ######################################################################## # Base class ######################################################################## @@ -1701,9 +1787,18 @@ def _encode(self): # pragma: no cover raise NotImplementedError() + def _encode_cer(self, writer): + write_full(writer, self._encode()) + def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode): # pragma: no cover yield NotImplemented + def _encode1st(self, state): + raise NotImplementedError() + + def _encode2nd(self, writer, state_iter): + raise NotImplementedError() + def encode(self): """DER encode the structure @@ -1714,6 +1809,36 @@ if self._expl is None: return raw return b"".join((self._expl, len_encode(len(raw)), raw)) + def encode1st(self, state=None): + """Do the 1st pass of 2-pass encoding + + :rtype: (int, array("L")) + :returns: full length of encoded data and precalculated various + objects lengths + """ + if state is None: + state = state_2pass_new() + if self._expl is None: + return self._encode1st(state) + state.append(0) + idx = len(state) - 1 + vlen, _ = self._encode1st(state) + state[idx] = vlen + fulllen = len(self._expl) + len_size(vlen) + vlen + return fulllen, state + + def encode2nd(self, writer, state_iter): + """Do the 2nd pass of 2-pass encoding + + :param writer: must comply with ``io.RawIOBase.write`` behaviour + :param state_iter: iterator over the 1st pass state (``iter(state)``) + """ + if self._expl is None: + self._encode2nd(writer, state_iter) + else: + write_full(writer, self._expl + len_encode(next(state_iter))) + self._encode2nd(writer, state_iter) + def encode_cer(self, writer): """CER encode the structure to specified writer @@ -1730,9 +1855,6 @@ else: self._encode_cer(writer) if self._expl is not None: write_full(writer, EOC) - - def _encode_cer(self, writer): - write_full(writer, self._encode()) def 
hexencode(self): """Do hexadecimal encoded :py:meth:`pyderasn.Obj.encode` @@ -2045,6 +2167,17 @@ obj.encode_cer(buf.write) return buf.getvalue() +def encode2pass(obj): + """Encode (2-pass mode) to DER in memory buffer + + :returns bytes: memory buffer contents + """ + buf = BytesIO() + _, state = obj.encode1st() + obj.encode2nd(buf.write, iter(state)) + return buf.getvalue() + + class DecodePathDefBy(object): """DEFINED BY representation inside decode path """ @@ -2464,6 +2597,13 @@ def _encode(self): self._assert_ready() return b"".join((self.tag, LEN1, (b"\xFF" if self._value else b"\x00"))) + def _encode1st(self, state): + return len(self.tag) + 2, state + + def _encode2nd(self, writer, state_iter): + self._assert_ready() + write_full(writer, self._encode()) + def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode): try: t, _, lv = tag_strip(tlv) @@ -2755,7 +2895,7 @@ optional=self.optional if optional is None else optional, _specs=self.specs, ) - def _encode(self): + def _encode_payload(self): self._assert_ready() value = self._value if PY2: @@ -2792,8 +2932,20 @@ except OverflowError: bytes_len += 1 else: break + return octets return b"".join((self.tag, len_encode(len(octets)), octets)) + def _encode(self): + octets = self._encode_payload() + return b"".join((self.tag, len_encode(len(octets)), octets)) + + def _encode1st(self, state): + l = len(self._encode_payload()) + return len(self.tag) + len_size(l) + l, state + + def _encode2nd(self, writer, state_iter): + write_full(writer, self._encode()) + def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode): try: t, _, lv = tag_strip(tlv) @@ -3178,6 +3330,21 @@ int2byte((8 - bit_len % 8) % 8), octets, )) + def _encode1st(self, state): + self._assert_ready() + _, octets = self._value + l = len(octets) + 1 + return len(self.tag) + len_size(l) + l, state + + def _encode2nd(self, writer, state_iter): + bit_len, octets = self._value + write_full(writer, b"".join(( + self.tag, + 
len_encode(len(octets) + 1), + int2byte((8 - bit_len % 8) % 8), + ))) + write_full(writer, octets) + def _encode_cer(self, writer): bit_len, octets = self._value if len(octets) + 1 <= 1000: @@ -3629,6 +3796,16 @@ len_encode(len(self._value)), self._value, )) + def _encode1st(self, state): + self._assert_ready() + l = len(self._value) + return len(self.tag) + len_size(l) + l, state + + def _encode2nd(self, writer, state_iter): + value = self._value + write_full(writer, self.tag + len_encode(len(value))) + write_full(writer, value) + def _encode_cer(self, writer): octets = self._value if len(octets) <= 1000: @@ -3987,6 +4164,12 @@ def _encode(self): return self.tag + LEN0 + def _encode1st(self, state): + return len(self.tag) + 1, state + + def _encode2nd(self, writer, state_iter): + write_full(writer, self.tag + LEN0) + def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode): try: t, _, lv = tag_strip(tlv) @@ -4239,7 +4422,7 @@ default=self.default if default is None else default, optional=self.optional if optional is None else optional, ) - def _encode(self): + def _encode_octets(self): self._assert_ready() value = self._value first_value = value[1] @@ -4255,8 +4438,18 @@ raise RuntimeError("invalid arc is stored") octets = [zero_ended_encode(first_value)] for arc in value[2:]: octets.append(zero_ended_encode(arc)) - v = b"".join(octets) + return b"".join(octets) + + def _encode(self): + v = self._encode_octets() return b"".join((self.tag, len_encode(len(v)), v)) + + def _encode1st(self, state): + l = len(self._encode_octets()) + return len(self.tag) + len_size(l) + l, state + + def _encode2nd(self, writer, state_iter): + write_full(writer, self._encode()) def _decode(self, tlv, offset, decode_path, ctx, tag_only, evgen_mode): try: @@ -5017,6 +5210,13 @@ def _encode(self): self._assert_ready() return b"".join((self.tag, LEN_LEN_YYMMDDHHMMSSZ, self._encode_time())) + def _encode1st(self, state): + return len(self.tag) + LEN_YYMMDDHHMMSSZ_WITH_LEN, 
state + + def _encode2nd(self, writer, state_iter): + self._assert_ready() + write_full(writer, self._encode()) + def _encode_cer(self, writer): write_full(writer, self._encode()) @@ -5188,6 +5388,14 @@ if value.microsecond > 0: encoded = self._encode_time() return b"".join((self.tag, len_encode(len(encoded)), encoded)) return b"".join((self.tag, LEN_LEN_YYYYMMDDHHMMSSZ, self._encode_time())) + + def _encode1st(self, state): + self._assert_ready() + vlen = len(self._encode_time()) + return len(self.tag) + len_size(vlen) + vlen, state + + def _encode2nd(self, writer, state_iter): + write_full(writer, self._encode()) class GraphicString(CommonString): @@ -5433,6 +5641,13 @@ def _encode(self): self._assert_ready() return self._value[1].encode() + def _encode1st(self, state): + self._assert_ready() + return self._value[1].encode1st(state) + + def _encode2nd(self, writer, state_iter): + self._value[1].encode2nd(writer, state_iter) + def _encode_cer(self, writer): self._assert_ready() self._value[1].encode_cer(writer) @@ -5701,6 +5916,20 @@ if value.__class__ == binary_type: return value return value.encode() + def _encode1st(self, state): + self._assert_ready() + value = self._value + if value.__class__ == binary_type: + return len(value), state + return value.encode1st(state) + + def _encode2nd(self, writer, state_iter): + value = self._value + if value.__class__ == binary_type: + write_full(writer, value) + else: + value.encode2nd(writer, state_iter) + def _encode_cer(self, writer): self._assert_ready() value = self._value @@ -5859,7 +6088,19 @@ **NAMEDTUPLE_KWARGS ) -class Sequence(Obj): +class SequenceEncode1stMixing(object): + def _encode1st(self, state): + state.append(0) + idx = len(state) - 1 + vlen = 0 + for v in self._values_for_encoding(): + l, _ = v.encode1st(state) + vlen += l + state[idx] = vlen + return len(self.tag) + len_size(vlen) + vlen, state + + +class Sequence(SequenceEncode1stMixing, Obj): """``SEQUENCE`` structure type You have to make 
specification of sequence:: @@ -6106,6 +6347,11 @@ def _encode(self): v = b"".join(v.encode() for v in self._values_for_encoding()) return b"".join((self.tag, len_encode(len(v)), v)) + def _encode2nd(self, writer, state_iter): + write_full(writer, self.tag + len_encode(next(state_iter))) + for v in self._values_for_encoding(): + v.encode2nd(writer, state_iter) + def _encode_cer(self, writer): write_full(writer, self.tag + LENINDEF) for v in self._values_for_encoding(): @@ -6350,7 +6596,7 @@ for pp in self.pps_lenindef(decode_path): yield pp -class Set(Sequence): +class Set(Sequence, SequenceEncode1stMixing): """``SET`` structure type Its usage is identical to :py:class:`pyderasn.Sequence`. @@ -6552,7 +6798,7 @@ **NAMEDTUPLE_KWARGS ) -class SequenceOf(Obj): +class SequenceOf(SequenceEncode1stMixing, Obj): """``SEQUENCE OF`` sequence type For that kind of type you must specify the object it will carry on @@ -6781,6 +7027,31 @@ else: value = b"".join(v.encode() for v in self._values_for_encoding()) return b"".join((self.tag, len_encode(len(value)), value)) + def _encode1st(self, state): + state = super(SequenceOf, self)._encode1st(state) + if hasattr(self._value, NEXT_ATTR_NAME): + self._value = [] + return state + + def _encode2nd(self, writer, state_iter): + write_full(writer, self.tag + len_encode(next(state_iter))) + iterator = hasattr(self._value, NEXT_ATTR_NAME) + if iterator: + values_count = 0 + class_expected = self.spec.__class__ + values_for_encoding = self._values_for_encoding() + self._value = [] + for v in values_for_encoding: + if not isinstance(v, class_expected): + raise InvalidValueType((class_expected,)) + v.encode2nd(writer, state_iter) + values_count += 1 + if not self._bound_min <= values_count <= self._bound_max: + raise BoundsError(self._bound_min, values_count, self._bound_max) + else: + for v in self._values_for_encoding(): + v.encode2nd(writer, state_iter) + def _encode_cer(self, writer): write_full(writer, self.tag + LENINDEF) iterator = 
hasattr(self._value, NEXT_ATTR_NAME) @@ -7003,6 +7274,17 @@ def _encode(self): v = b"".join(sorted(v.encode() for v in self._values_for_encoding())) return b"".join((self.tag, len_encode(len(v)), v)) + + def _encode2nd(self, writer, state_iter): + write_full(writer, self.tag + len_encode(next(state_iter))) + values = [] + for v in self._values_for_encoding(): + buf = BytesIO() + v.encode2nd(buf.write, state_iter) + values.append(buf.getvalue()) + values.sort() + for v in values: + write_full(writer, v) def _encode_cer(self, writer): write_full(writer, self.tag + LENINDEF) diff --git a/tests/test_cms.py b/tests/test_cms.py index 679fae4f7ec17299d9496bbd32eea918f6ae93495a47931b53cb54f301bf7f7d..3a83c1cc67bdc4722d514296b53918495e1ad2c459b4adab249071d2c7c61d2a 100644 --- a/tests/test_cms.py +++ b/tests/test_cms.py @@ -22,6 +22,7 @@ from os import environ from os import remove from os import urandom from subprocess import call +from sys import getsizeof from tempfile import NamedTemporaryFile from time import time from unittest import skipIf @@ -276,6 +277,10 @@ ))])), ))))), )) cms_path = self.tmpfile() + _, state = ci.encode1st() + with io_open(cms_path, "wb") as fd: + ci.encode2nd(fd.write, iter(state)) + self.verify(cert_path, cms_path) with io_open(cms_path, "wb") as fd: ci.encode_cer(fd.write) self.verify(cert_path, cms_path) @@ -290,29 +295,30 @@ buf = BytesIO() agg_octet_string(evgens, ("encapContentInfo", "eContent"), raw, buf.write) self.assertSequenceEqual(buf.getvalue(), data) + def create_huge_file(self): + rnd = urandom(1<<20) + data_path = self.tmpfile() + start = time() + with open(data_path, "wb") as fd: + for _ in six_xrange(int(environ.get("PYDERASN_TEST_CMS_HUGE"))): + # dgst.update(rnd) + fd.write(rnd) + print("data file written", time() - start) + return file_mmaped(open(data_path, "rb")) + @skipIf(PY2, "no mmaped memoryview support in PY2") @skipIf("PYDERASN_TEST_CMS_HUGE" not in environ, "PYDERASN_TEST_CMS_HUGE is not set") - def test_huge(self): 
+ def test_huge_cer(self): """Huge CMS test Environment variable PYDERASN_TEST_CMS_HUGE tells how many MiBs data to sign. Pay attention that openssl cms is unable to do stream verification and eats huge amounts (several times more, - that CMS itself) of memory. + than CMS itself) of memory. """ + data_raw = self.create_huge_file() key_path, cert_path, cert, skid = self.keypair() - rnd = urandom(1<<20) - data_path = self.tmpfile() - start = time() - with open(data_path, "wb") as fd: - for _ in six_xrange(int(environ.get("PYDERASN_TEST_CMS_HUGE"))): - # dgst.update(rnd) - fd.write(rnd) - print("data file written", time() - start) - data_fd = open(data_path, "rb") - data_raw = file_mmaped(data_fd) - - from sys import getallocatedblocks + from sys import getallocatedblocks # PY2 does not have it mem_start = getallocatedblocks() start = time() eci = EncapsulatedContentInfo(( @@ -376,3 +382,61 @@ with io_open(cms_path, "wb") as fd: ci.encode_cer(fd.write) print("CMS written", time() - start) self.verify(cert_path, cms_path) + + @skipIf(PY2, "no mmaped memoryview support in PY2") + @skipIf("PYDERASN_TEST_CMS_HUGE" not in environ, "PYDERASN_TEST_CMS_HUGE is not set") + def test_huge_der_2pass(self): + """Same test as above, but 2pass DER encoder and just signature verification + """ + data_raw = self.create_huge_file() + key_path, cert_path, cert, skid = self.keypair() + from sys import getallocatedblocks + mem_start = getallocatedblocks() + dgst = sha512(data_raw).digest() + start = time() + eci = EncapsulatedContentInfo(( + ("eContentType", ContentType(id_data)), + ("eContent", OctetString(data_raw)), + )) + signed_attrs = SignedAttributes([ + Attribute(( + ("attrType", id_pkcs9_at_contentType), + ("attrValues", AttributeValues([AttributeValue(id_data)])), + )), + Attribute(( + ("attrType", id_pkcs9_at_messageDigest), + ("attrValues", AttributeValues([AttributeValue(OctetString(dgst))])), + )), + ]) + signature = self.sign(signed_attrs, key_path) + 
self.assertLess(getallocatedblocks(), mem_start * 2) + start = time() + ci = ContentInfo(( + ("contentType", ContentType(id_signedData)), + ("content", Any((SignedData(( + ("version", CMSVersion("v3")), + ("digestAlgorithms", DigestAlgorithmIdentifiers([ai_sha512])), + ("encapContentInfo", eci), + ("certificates", CertificateSet([ + CertificateChoices(("certificate", cert)), + ])), + ("signerInfos", SignerInfos([SignerInfo(( + ("version", CMSVersion("v3")), + ("sid", SignerIdentifier(("subjectKeyIdentifier", skid))), + ("digestAlgorithm", DigestAlgorithmIdentifier(ai_sha512)), + ("signedAttrs", signed_attrs), + ("signatureAlgorithm", SignatureAlgorithmIdentifier(( + ("algorithm", id_ecdsa_with_SHA512), + ))), + ("signature", SignatureValue(signature)), + ))])), + ))))), + )) + _, state = ci.encode1st() + print("2pass state size", getsizeof(state)) + cms_path = self.tmpfile() + with io_open(cms_path, "wb") as fd: + ci.encode2nd(fd.write, iter(state)) + print("CMS written", time() - start) + self.assertLess(getallocatedblocks(), mem_start * 2) + self.verify(cert_path, cms_path) diff --git a/tests/test_crl.py b/tests/test_crl.py index 157844d479c46bb1bd93f05c7120ff81ce549e9c193e6e6448fed0470407189a..0f6c371484bc7295d46c5576aad15b7cf0e215337392e35b08178c812a64dfa4 100644 --- a/tests/test_crl.py +++ b/tests/test_crl.py @@ -17,7 +17,9 @@ # . 
"""CRL related schemas, just to test the performance with them """ +from io import BytesIO from os.path import exists +from sys import getsizeof from time import time from unittest import skipIf from unittest import TestCase @@ -76,7 +78,7 @@ @skipIf(not exists(CRL_PATH), "CACert's revoke.crl not found") class TestCACert(TestCase): - def test_cer(self): + def test_cer_and_2pass(self): with open(CRL_PATH, "rb") as fd: raw = fd.read() print("DER read") @@ -84,16 +86,23 @@ start = time() crl1 = CertificateList().decod(raw) print("DER decoded", time() - start) start = time() + der_raw = crl1.encode() + print("DER encoded", time() - start) + self.assertSequenceEqual(der_raw, raw) + buf = BytesIO() + start = time() + _, state = crl1.encode1st() + print("1st pass state size", getsizeof(state)) + crl1.encode2nd(buf.write, iter(state)) + print("DER 2pass encoded", time() - start) + self.assertSequenceEqual(buf.getvalue(), raw) + start = time() cer_raw = encode_cer(crl1) print("CER encoded", time() - start) start = time() crl2 = CertificateList().decod(cer_raw, ctx={"bered": True}) print("CER decoded", time() - start) self.assertEqual(crl2, crl1) - start = time() - der_raw = crl2.encode() - print("DER encoded", time() - start) - self.assertSequenceEqual(der_raw, raw) @skipIf(PY2, "Py27 mmap does not implement buffer protocol") def test_mmaped(self): diff --git a/tests/test_pyderasn.py b/tests/test_pyderasn.py index 4c7d35bc45ff6d8c3be876c54fa48d027e005bd800c152dc2ea5efca4ee37a30..fa16f149098d577965809a2aa823f8ace63a96eb71e5e26c8f9e1fae1a3c1c4e 100644 --- a/tests/test_pyderasn.py +++ b/tests/test_pyderasn.py @@ -20,6 +20,7 @@ from copy import deepcopy from datetime import datetime from datetime import timedelta from importlib import import_module +from io import BytesIO from operator import attrgetter from os import environ from os import urandom @@ -75,6 +76,7 @@ from pyderasn import BoundsError from pyderasn import Choice from pyderasn import DecodeError from pyderasn 
import DecodePathDefBy +from pyderasn import encode2pass from pyderasn import encode_cer from pyderasn import Enumerated from pyderasn import EOC @@ -422,6 +424,8 @@ list(obj.pps()) pprint(obj, big_blobs=True, with_decode_path=True) with self.assertRaises(ObjNotReady) as err: obj.encode() + with self.assertRaises(ObjNotReady) as err: + encode2pass(obj) repr(err.exception) obj = Boolean(value) self.assertTrue(obj.ready) @@ -506,6 +510,8 @@ def test_stripped(self, value, tag_impl): obj = Boolean(value, impl=tag_impl) with self.assertRaises(NotEnoughData): obj.decode(obj.encode()[:-1]) + with self.assertRaises(NotEnoughData): + obj.decode(encode2pass(obj)[:-1]) @given( booleans(), @@ -515,6 +521,8 @@ def test_stripped_expl(self, value, tag_expl): obj = Boolean(value, expl=tag_expl) with self.assertRaises(NotEnoughData): obj.decode(obj.encode()[:-1]) + with self.assertRaises(NotEnoughData): + obj.decode(encode2pass(obj)[:-1]) @given( integers(min_value=31), @@ -603,6 +611,7 @@ list(obj.pps()) pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) self.assertSequenceEqual(encode_cer(obj), obj_encoded) obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) @@ -863,6 +872,8 @@ list(obj.pps()) pprint(obj, big_blobs=True, with_decode_path=True) with self.assertRaises(ObjNotReady) as err: obj.encode() + with self.assertRaises(ObjNotReady) as err: + encode2pass(obj) repr(err.exception) obj = Integer(value) self.assertTrue(obj.ready) @@ -925,6 +936,10 @@ Integer(bounds=(values[1], values[2])).decode( Integer(values[0]).encode() ) repr(err.exception) + with assertRaisesRegex(self, DecodeError, "bounds") as err: + Integer(bounds=(values[1], values[2])).decode( + encode2pass(Integer(values[0])) + ) with self.assertRaises(BoundsError) as err: Integer(value=values[2], bounds=(values[0], values[1])) repr(err.exception) @@ -933,6 +948,10 @@ 
Integer(bounds=(values[0], values[1])).decode( Integer(values[2]).encode() ) repr(err.exception) + with assertRaisesRegex(self, DecodeError, "bounds") as err: + Integer(bounds=(values[0], values[1])).decode( + encode2pass(Integer(values[2])) + ) @given(data_strategy()) def test_call(self, d): @@ -1124,6 +1143,7 @@ list(obj.pps()) pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) self.assertSequenceEqual(encode_cer(obj), obj_encoded) obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) @@ -1362,6 +1382,8 @@ pprint(obj, big_blobs=True, with_decode_path=True) with self.assertRaises(ObjNotReady) as err: obj.encode() repr(err.exception) + with self.assertRaises(ObjNotReady) as err: + encode2pass(obj) obj = BitString(value) self.assertTrue(obj.ready) repr(obj) @@ -1540,6 +1562,7 @@ list(obj.pps()) pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) self.assertSequenceEqual(encode_cer(obj), obj_encoded) obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) @@ -1965,6 +1988,8 @@ pprint(obj, big_blobs=True, with_decode_path=True) with self.assertRaises(ObjNotReady) as err: obj.encode() repr(err.exception) + with self.assertRaises(ObjNotReady) as err: + encode2pass(obj) obj = OctetString(value) self.assertTrue(obj.ready) repr(obj) @@ -2011,6 +2036,10 @@ OctetString(bounds=(bound_min, bound_max)).decode( OctetString(value).encode() ) repr(err.exception) + with assertRaisesRegex(self, DecodeError, "bounds") as err: + OctetString(bounds=(bound_min, bound_max)).decode( + encode2pass(OctetString(value)) + ) value = d.draw(binary(min_size=bound_max + 1)) with self.assertRaises(BoundsError) as err: OctetString(value=value, bounds=(bound_min, bound_max)) @@ -2020,6 +2049,10 @@ OctetString(bounds=(bound_min, bound_max)).decode( 
OctetString(value).encode() ) repr(err.exception) + with assertRaisesRegex(self, DecodeError, "bounds") as err: + OctetString(bounds=(bound_min, bound_max)).decode( + encode2pass(OctetString(value)) + ) @given(data_strategy()) def test_call(self, d): @@ -2196,6 +2229,7 @@ list(obj.pps()) pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) self.assertSequenceEqual(encode_cer(obj), obj_encoded) obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) @@ -2566,6 +2600,7 @@ list(obj.pps()) pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) self.assertSequenceEqual(encode_cer(obj), obj_encoded) obj_expled = obj(expl=tag_expl) self.assertTrue(obj_expled.expled) @@ -2695,6 +2730,8 @@ pprint(obj, big_blobs=True, with_decode_path=True) with self.assertRaises(ObjNotReady) as err: obj.encode() repr(err.exception) + with self.assertRaises(ObjNotReady) as err: + encode2pass(obj) obj = ObjectIdentifier(value) self.assertTrue(obj.ready) self.assertFalse(obj.ber_encoded) @@ -2932,6 +2969,7 @@ list(obj.pps()) pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) self.assertSequenceEqual(encode_cer(obj), obj_encoded) obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) @@ -3327,6 +3365,7 @@ list(obj.pps()) pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) repr(obj_expled) @@ -3444,6 +3483,8 @@ text_type(obj) with self.assertRaises(ObjNotReady) as err: obj.encode() repr(err.exception) + with self.assertRaises(ObjNotReady) as err: + encode2pass(obj) value = 
d.draw(text(alphabet=self.text_alphabet())) obj = self.base_klass(value) self.assertTrue(obj.ready) @@ -3493,6 +3534,10 @@ self.base_klass(bounds=(bound_min, bound_max)).decode( self.base_klass(value).encode() ) repr(err.exception) + with assertRaisesRegex(self, DecodeError, "bounds") as err: + self.base_klass(bounds=(bound_min, bound_max)).decode( + encode2pass(self.base_klass(value)) + ) value = d.draw(text(alphabet=self.text_alphabet(), min_size=bound_max + 1)) with self.assertRaises(BoundsError) as err: self.base_klass(value=value, bounds=(bound_min, bound_max)) @@ -3502,6 +3547,10 @@ self.base_klass(bounds=(bound_min, bound_max)).decode( self.base_klass(value).encode() ) repr(err.exception) + with assertRaisesRegex(self, DecodeError, "bounds") as err: + self.base_klass(bounds=(bound_min, bound_max)).decode( + encode2pass(self.base_klass(value)) + ) @given(data_strategy()) def test_call(self, d): @@ -3677,6 +3726,7 @@ list(obj.pps()) pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) repr(obj_expled) @@ -4033,6 +4083,8 @@ pprint(obj, big_blobs=True, with_decode_path=True) with self.assertRaises(ObjNotReady) as err: obj.encode() repr(err.exception) + with self.assertRaises(ObjNotReady) as err: + encode2pass(obj) value = d.draw(datetimes( min_value=self.min_datetime, max_value=self.max_datetime, @@ -4191,6 +4243,7 @@ list(obj.pps()) pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) self.additional_symmetric_check(value, obj_encoded) obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) @@ -5000,6 +5053,8 @@ pprint(obj, big_blobs=True, with_decode_path=True) with self.assertRaises(ObjNotReady) as err: obj.encode() repr(err.exception) + with 
self.assertRaises(ObjNotReady) as err: + encode2pass(obj) obj = Any(value) self.assertTrue(obj.ready) repr(obj) @@ -5148,6 +5203,7 @@ self.assertFalse(obj.expled) tag_class, _, tag_num = tag_decode(tag_strip(value)[0]) self.assertEqual(obj.tag_order, (tag_class, tag_num)) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) tag_class, _, tag_num = tag_decode(tag_expl) @@ -5384,6 +5440,8 @@ self.assertIsNone(obj["whatever"]) with self.assertRaises(ObjNotReady) as err: obj.encode() repr(err.exception) + with self.assertRaises(ObjNotReady) as err: + encode2pass(obj) obj["whatever"] = Boolean() self.assertFalse(obj.ready) repr(obj) @@ -5532,6 +5590,7 @@ pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) self.assertEqual(obj.tag_order, obj.value.tag_order) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) obj_expled = obj(value, expl=tag_expl) self.assertTrue(obj_expled.expled) tag_class, _, tag_num = tag_decode(tag_expl) @@ -5879,6 +5938,8 @@ pprint(seq, big_blobs=True, with_decode_path=True) with self.assertRaises(ObjNotReady) as err: seq.encode() repr(err.exception) + with self.assertRaises(ObjNotReady) as err: + encode2pass(seq) for name, value in non_ready.items(): seq[name] = Boolean(value) self.assertTrue(seq.ready) @@ -6067,6 +6128,7 @@ list(seq.pps()) pprint(seq, big_blobs=True, with_decode_path=True) self.assertTrue(seq.ready) seq_encoded = seq.encode() + self.assertEqual(encode2pass(seq), seq_encoded) seq_encoded_cer = encode_cer(seq) self.assertNotEqual(seq_encoded_cer, seq_encoded) self.assertSequenceEqual( @@ -6155,6 +6217,7 @@ def test_symmetric_with_seq(self, d): seq, expect_outers = d.draw(sequences_strategy(seq_klass=self.base_klass)) self.assertTrue(seq.ready) seq_encoded = seq.encode() + self.assertEqual(encode2pass(seq), seq_encoded) seq_decoded, tail = seq.decode(seq_encoded) 
self.assertEqual(tail, b"") self.assertTrue(seq.ready) @@ -6527,6 +6590,8 @@ pprint(seqof, big_blobs=True, with_decode_path=True) with self.assertRaises(ObjNotReady) as err: seqof.encode() repr(err.exception) + with self.assertRaises(ObjNotReady) as err: + encode2pass(seqof) for i, value in enumerate(values): self.assertEqual(seqof[i], value) if not seqof[i].ready: @@ -6570,6 +6635,10 @@ SeqOf(bounds=(bound_min, bound_max)).decode( SeqOf(value).encode() ) repr(err.exception) + with assertRaisesRegex(self, DecodeError, "bounds") as err: + SeqOf(bounds=(bound_min, bound_max)).decode( + encode2pass(SeqOf(value)) + ) value = [Boolean(True)] * d.draw(integers( min_value=bound_max + 1, max_value=bound_max + 10, @@ -6582,6 +6651,10 @@ SeqOf(bounds=(bound_min, bound_max)).decode( SeqOf(value).encode() ) repr(err.exception) + with assertRaisesRegex(self, DecodeError, "bounds") as err: + SeqOf(bounds=(bound_min, bound_max)).decode( + encode2pass(SeqOf(value)) + ) @given(integers(min_value=1, max_value=10)) def test_out_of_bounds(self, bound_max): @@ -6788,6 +6861,7 @@ list(obj.pps()) pprint(obj, big_blobs=True, with_decode_path=True) self.assertFalse(obj.expled) obj_encoded = obj.encode() + self.assertEqual(encode2pass(obj), obj_encoded) obj_encoded_cer = encode_cer(obj) self.assertNotEqual(obj_encoded_cer, obj_encoded) self.assertSequenceEqual( @@ -6972,6 +7046,26 @@ seqof.encode() self.assertFalse(seqof.ready) register_class(SeqOf) pickle_dumps(seqof) + + def test_iterator_2pass(self): + class SeqOf(SequenceOf): + schema = Integer() + bounds = (1, float("+inf")) + def gen(): + for i in six_xrange(10): + yield Integer(i) + seqof = SeqOf(gen()) + self.assertTrue(seqof.ready) + _, state = seqof.encode1st() + self.assertFalse(seqof.ready) + seqof = seqof(gen()) + self.assertTrue(seqof.ready) + buf = BytesIO() + seqof.encode2nd(buf.write, iter(state)) + self.assertSequenceEqual( + [int(i) for i in seqof.decod(buf.getvalue())], + list(gen()), + ) def 
test_non_ready_bound_min(self): class SeqOf(SequenceOf):