doc/news.rst | 2 ++ pyderasn.py | 86 +++++++++++++++++++++++++++++++++++++++++++++-------- tests/test_pyderasn.py | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++ diff --git a/doc/news.rst b/doc/news.rst index 1dc6c6e449d8019fa391e9cb7af6ca4aab87883ee2b432b1fd1a1f423cbe290a..4b7baf3b2594a05dd2f8fa7b1ef9d00f0248aa612ed8a4c9e0a2bead3ba4d514 100644 --- a/doc/news.rst +++ b/doc/news.rst @@ -11,6 +11,8 @@ * Fixed possibly invalid SET DER encoding where objects were not sorted by tag, but by encoded representation * ``Any`` does not allow empty data value now. Now it checks if it has valid ASN.1 tag +* ``SetOf`` is not treated as ready, if no value was set and minimum + bounds are greater than zero * ``Any`` allows an ordinary ``Obj`` storing, without its forceful encoded representation storage * Initial support for so called ``evgen_mode``: event generation mode, diff --git a/pyderasn.py b/pyderasn.py index 8ba6982802ad51a0ce21251e60c79d6a942554208a1cbf73fe9a714a0afe6389..9cd69efddd90b883b14f233c61e59d5b1c61cb5318215a30e40d254995b74ae6 100755 --- a/pyderasn.py +++ b/pyderasn.py @@ -894,6 +894,7 @@ NAMEDTUPLE_KWARGS = {} if version_info < (3, 6) else {"module": __name__} SET01 = frozenset("01") DECIMALS = frozenset(digits) DECIMAL_SIGNS = ".," +NEXT_ATTR_NAME = "next" if PY2 else "__next__" def file_mmaped(fd): @@ -6226,9 +6227,21 @@ >>> ints[1] = Integer(345) >>> ints Ints SEQUENCE OF[INTEGER 123, INTEGER 345] - Also you can initialize sequence with preinitialized values: + You can initialize sequence with preinitialized values: >>> ints = Ints([Integer(123), Integer(234)]) + + Also you can use iterator as a value: + + >>> ints = Ints(iter(Integer(i) for i in range(1000000))) + + And it won't be iterated until encoding process. Pay attention that + bounds and required schema checks are done only during the encoding + process in that case! After encode was called, then value is zeroed + back to empty list and you have to set it again. That mode is useful + mainly with CER encoding mode, where all objects from the iterable + will be streamed to the buffer, without copying all of them to + memory first. """ __slots__ = ("spec", "_bound_min", "_bound_max") tag_default = tag_encode(form=TagFormConstructed, num=16) @@ -6272,21 +6285,31 @@ if value is None: self._value = copy(default_obj._value) def _value_sanitize(self, value): + iterator = False if issubclass(value.__class__, SequenceOf): value = value._value + elif hasattr(value, NEXT_ATTR_NAME): + iterator = True + value = value elif hasattr(value, "__iter__"): value = list(value) else: - raise InvalidValueType((self.__class__, iter)) - if not self._bound_min <= len(value) <= self._bound_max: - raise BoundsError(self._bound_min, len(value), self._bound_max) - for v in value: - if not isinstance(v, self.spec.__class__): - raise InvalidValueType((self.spec.__class__,)) + raise InvalidValueType((self.__class__, iter, "iterator")) + if not iterator: + if not self._bound_min <= len(value) <= self._bound_max: + raise BoundsError(self._bound_min, len(value), self._bound_max) + class_expected = self.spec.__class__ + for v in value: + if not isinstance(v, class_expected): + raise InvalidValueType((class_expected,)) return value @property def ready(self): + if hasattr(self._value, NEXT_ATTR_NAME): + return True + if self._bound_min > 0 and len(self._value) == 0: + return False return all(v.ready for v in self._value) @property @@ -6296,6 +6319,8 @@ return True return any(v.bered for v in self._value) def __getstate__(self): + if hasattr(self._value, NEXT_ATTR_NAME): + raise ValueError("can not pickle SequenceOf with iterator") return SequenceOfState( __version__, self.tag, @@ -6371,11 +6396,9 @@ ) self._value.append(value) def __iter__(self): - self._assert_ready() return iter(self._value) def __len__(self): - self._assert_ready() return len(self._value) def __setitem__(self, key, value): @@ -6390,13 +6413,42 @@ def _values_for_encoding(self): return iter(self._value) def _encode(self): - v = b"".join(v.encode() for v in self._values_for_encoding()) - return b"".join((self.tag, len_encode(len(v)), v)) + iterator = hasattr(self._value, NEXT_ATTR_NAME) + if iterator: + values = [] + values_append = values.append + class_expected = self.spec.__class__ + values_for_encoding = self._values_for_encoding() + self._value = [] + for v in values_for_encoding: + if not isinstance(v, class_expected): + raise InvalidValueType((class_expected,)) + values_append(v.encode()) + if not self._bound_min <= len(values) <= self._bound_max: + raise BoundsError(self._bound_min, len(values), self._bound_max) + value = b"".join(values) + else: + value = b"".join(v.encode() for v in self._values_for_encoding()) + return b"".join((self.tag, len_encode(len(value)), value)) def _encode_cer(self, writer): write_full(writer, self.tag + LENINDEF) - for v in self._values_for_encoding(): - v.encode_cer(writer) + iterator = hasattr(self._value, NEXT_ATTR_NAME) + if iterator: + class_expected = self.spec.__class__ + values_count = 0 + values_for_encoding = self._values_for_encoding() + self._value = [] + for v in values_for_encoding: + if not isinstance(v, class_expected): + raise InvalidValueType((class_expected,)) + v.encode_cer(writer) + values_count += 1 + if not self._bound_min <= values_count <= self._bound_max: + raise BoundsError(self._bound_min, values_count, self._bound_max) + else: + for v in self._values_for_encoding(): + v.encode_cer(writer) write_full(writer, EOC) def _decode( @@ -6589,6 +6641,14 @@ """ __slots__ = () tag_default = tag_encode(form=TagFormConstructed, num=17) asn1_type_name = "SET OF" + + def _value_sanitize(self, value): + value = super(SetOf, self)._value_sanitize(value) + if hasattr(value, NEXT_ATTR_NAME): + raise ValueError( + "SetOf does not support iterator values, as no sense in them" + ) + return value def _encode(self): v = b"".join(sorted(v.encode() for v in self._values_for_encoding())) diff --git a/tests/test_pyderasn.py b/tests/test_pyderasn.py index 10214031517e69c816d34d2f2995286556f8873478fafc8880d4b7477cb009ea..c203e508a911932276b9e0fec5d0f9ed3933061ecaed113d15a4e91ba243a866 100644 --- a/tests/test_pyderasn.py +++ b/tests/test_pyderasn.py @@ -60,6 +60,7 @@ from six import iterbytes from six import PY2 from six import text_type from six import unichr as six_unichr +from six.moves import xrange as six_xrange from six.moves.cPickle import dumps as pickle_dumps from six.moves.cPickle import HIGHEST_PROTOCOL as pickle_proto from six.moves.cPickle import loads as pickle_loads @@ -6898,6 +6899,57 @@ def _test_symmetric_compare_objs(self, obj1, obj2): self.assertEqual(obj1, obj2) self.assertSequenceEqual(list(obj1), list(obj2)) + + def test_iterator_pickling(self): + class SeqOf(SequenceOf): + schema = Integer() + register_class(SeqOf) + seqof = SeqOf() + pickle_dumps(seqof) + seqof = seqof(iter(six_xrange(10))) + with assertRaisesRegex(self, ValueError, "iterator"): + pickle_dumps(seqof) + + def test_iterator_bounds(self): + class SeqOf(SequenceOf): + schema = Integer() + bounds = (10, 20) + seqof = None + def gen(n): + for i in six_xrange(n): + yield Integer(i) + for n in (9, 21): + seqof = SeqOf(gen(n)) + self.assertTrue(seqof.ready) + with self.assertRaises(BoundsError): + seqof.encode() + self.assertFalse(seqof.ready) + seqof = seqof(gen(n)) + self.assertTrue(seqof.ready) + with self.assertRaises(BoundsError): + encode_cer(seqof) + self.assertFalse(seqof.ready) + + def test_iterator_twice(self): + class SeqOf(SequenceOf): + schema = Integer() + bounds = (1, float("+inf")) + def gen(): + for i in six_xrange(10): + yield Integer(i) + seqof = SeqOf(gen()) + self.assertTrue(seqof.ready) + seqof.encode() + self.assertFalse(seqof.ready) + register_class(SeqOf) + pickle_dumps(seqof) + + def test_non_ready_bound_min(self): + class SeqOf(SequenceOf): + schema = Integer() + bounds = (1, float("+inf")) + seqof = SeqOf() + self.assertFalse(seqof.ready) class TestSetOf(SeqOfMixing, CommonMixin, TestCase):