Source code for rubymarshal.writer
# -*- coding: utf-8 -*-
from __future__ import division, unicode_literals
import re
import io
import math
from rubymarshal.classes import Symbol, UsrMarshal
from rubymarshal.constants import TYPE_BIGNUM, TYPE_STRING, TYPE_REGEXP, TYPE_ARRAY, TYPE_HASH, TYPE_USRMARSHAL, TYPE_NIL, TYPE_TRUE, TYPE_FALSE, \
TYPE_IVAR, TYPE_LINK, TYPE_SYMLINK, TYPE_SYMBOL, TYPE_FIXNUM
from rubymarshal.constants import TYPE_FLOAT
from rubymarshal.utils import write_ushort, write_sbyte, write_ubyte, integer_types, binary_type, text_type
__author__ = 'Matthieu Gallet'

# Concrete class of compiled regular expressions, looked up once so that
# ``isinstance`` checks can recognise pattern objects.
re_class = type(re.compile(''))

# A decimal literal whose fractional part ends in at least one zero
# (for example ``1.500`` or ``12.0``); no sign and no exponent allowed.
simple_float_re = re.compile(r'^\d+\.\d*0+$')
class Writer(object):
    """Serialize Python objects to a binary stream in Ruby's Marshal format.

    Supported values are ``None``, booleans, integers, floats, byte and text
    strings, lists, dicts, compiled regular expressions, ``Symbol`` and
    ``UsrMarshal`` instances.  One writer keeps the symbol and object tables
    of a single stream, so use a fresh instance for every dump.
    """

    def __init__(self, fd):
        """Bind the writer to *fd*, a binary file-like object with ``write``."""
        # Symbol name -> position of its first occurrence (for symbol links).
        self.symbols = {}
        # id(object) -> position of its first occurrence (for object links).
        self.objects = {}
        self.fd = fd

    def write(self, obj):
        """Write *obj* (recursively for containers) to the stream.

        :param obj: the value to serialize
        :raises ValueError: if *obj* has an unsupported type, or if a length
            or fixnum cannot be encoded (see :meth:`write_long`)
        """
        if obj is None:
            self.fd.write(TYPE_NIL)
        elif obj is False:
            self.fd.write(TYPE_FALSE)
        elif obj is True:
            self.fd.write(TYPE_TRUE)
        elif isinstance(obj, int) or isinstance(obj, integer_types[-1]):
            self._write_integer(obj)
        elif isinstance(obj, Symbol):
            self._write_symbol(obj)
        elif isinstance(obj, list):
            if self.must_write(obj):
                self.fd.write(TYPE_ARRAY)
                self.write_long(len(obj))
                for item in obj:
                    self.write(item)
        elif isinstance(obj, dict):
            if self.must_write(obj):
                self.fd.write(TYPE_HASH)
                self.write_long(len(obj))
                for key, value in obj.items():
                    self.write(key)
                    self.write(value)
        elif isinstance(obj, binary_type):
            self._write_string(obj, False)
        elif isinstance(obj, text_type):
            self._write_string(obj.encode('utf-8'), True)
        elif isinstance(obj, float):
            self._write_float(obj)
        elif isinstance(obj, re_class):
            self._write_regexp(obj)
        elif isinstance(obj, UsrMarshal):
            if self.must_write(obj):
                self.fd.write(TYPE_USRMARSHAL)
                self.write(Symbol(obj.cls))
                self.write(obj.values)
        else:
            raise ValueError(obj)

    def _write_integer(self, value):
        """Write *value* as a fixnum when it fits, otherwise as a bignum."""
        # Ruby only emits fixnums in the range -2**30 .. 2**30 - 1; anything
        # larger must use the bignum layout to stay loadable by Ruby.
        if -(2 ** 30) <= value < 2 ** 30:
            self.fd.write(TYPE_FIXNUM)
            self.write_long(value)
            return
        self.fd.write(TYPE_BIGNUM)
        self.fd.write(b'-' if value < 0 else b'+')
        magnitude = abs(value)
        # The magnitude is stored as little-endian 16-bit words.
        word_count = (magnitude.bit_length() + 15) // 16
        self.write_long(word_count)
        for _ in range(word_count):
            self.write_short(magnitude % 65536)
            magnitude //= 65536

    def _write_symbol(self, symbol):
        """Write *symbol*, or a link to it if its name was already written."""
        name = symbol.name
        if name in self.symbols:
            self.fd.write(TYPE_SYMLINK)
            self.write_long(self.symbols[name])
            return
        self.fd.write(TYPE_SYMBOL)
        self.symbols[name] = len(self.symbols)
        encoded = name.encode('utf-8')
        self.write_long(len(encoded))
        self.fd.write(encoded)

    def _write_string(self, data, is_utf8):
        """Write the bytes *data* as a string carrying one ``E`` instance variable.

        *is_utf8* must be the literal ``True`` or ``False``; it becomes the
        value of the ``E`` variable.
        """
        self.fd.write(TYPE_IVAR)
        self.fd.write(TYPE_STRING)
        self.write_long(len(data))
        self.fd.write(data)
        self.write_long(1)
        self.write(Symbol('E'))
        self.write(is_utf8)

    def _write_float(self, value):
        """Write *value* as its ``%.20g`` text representation."""
        text = '%.20g' % value
        if simple_float_re.match(text):
            text = text.rstrip('0')
        encoded = text.encode('utf-8')
        self.fd.write(TYPE_FLOAT)
        self.write_long(len(encoded))
        self.fd.write(encoded)

    def _write_regexp(self, regex):
        """Write the compiled pattern *regex* with its translated option bits."""
        # Ruby option bits: IGNORECASE = 1, EXTENDED = 2, MULTILINE = 4.
        # Ruby's MULTILINE makes "." match newlines, i.e. Python's DOTALL.
        options = 0
        if regex.flags & re.IGNORECASE:
            options |= 1
        if regex.flags & re.VERBOSE:
            options |= 2
        if regex.flags & re.DOTALL:
            options |= 4
        pattern = regex.pattern
        if isinstance(pattern, text_type):
            # Bytes patterns are written unchanged; only text needs encoding.
            pattern = pattern.encode('utf-8')
        self.fd.write(TYPE_IVAR)
        self.fd.write(TYPE_REGEXP)
        self.write_long(len(pattern))
        self.fd.write(pattern)
        write_ubyte(self.fd, options)
        self.write_long(1)
        self.write(Symbol('E'))
        self.write(False)

    def write_short(self, obj):
        """Write *obj* as an unsigned 16-bit integer."""
        write_ushort(self.fd, obj)

    def write_long(self, obj):
        """Write the integer *obj* in the packed variable-length encoding.

        Zero is a single null byte; 1..122 and -123..-1 are a single byte
        offset by 5; every other value is a signed byte count followed by
        that many little-endian bytes.

        :raises ValueError: if *obj* needs more than four payload bytes
        """
        if obj == 0:
            self.fd.write(b'\0')
        elif 0 < obj < 123:
            write_sbyte(self.fd, obj + 5)
        elif -124 < obj < 0:
            write_sbyte(self.fd, obj - 5)
        else:
            size = (obj.bit_length() + 7) // 8
            # The format allows at most four payload bytes: a count byte of
            # 5 (or -5) would be read back as the one-byte form of zero.
            if size > 4:
                raise ValueError('%d too long for serialization' % obj)
            if obj < 0:
                # Negative values are stored as their two's complement over
                # ``size`` bytes, flagged by a negative byte count.
                obj += 256 ** size
                write_sbyte(self.fd, -size)
            else:
                write_sbyte(self.fd, size)
            for _ in range(size):
                write_ubyte(self.fd, obj % 256)
                obj //= 256

    def must_write(self, obj):
        """Tell whether *obj* still has to be serialized in full.

        On the first call for a given object its position is recorded and
        ``True`` is returned.  On later calls a link to that position is
        written to the stream and ``False`` is returned.
        """
        # NOTE(review): positions only count objects passed to this method
        # (lists, dicts, UsrMarshal). Ruby's loader presumably also numbers
        # strings, floats, bignums and regexps -- confirm link indices
        # against real Ruby when shared containers follow such values.
        key = id(obj)
        if key in self.objects:
            self.fd.write(TYPE_LINK)
            self.write_long(self.objects[key])
            return False
        self.objects[key] = len(self.objects)
        return True
def write(fd, obj):
    """Dump *obj* to the binary stream *fd*.

    The two header bytes (0x04, 0x08) are written first, then *obj* is
    serialized by a fresh :class:`Writer` bound to *fd*.
    """
    fd.write(b'\x04\x08')
    Writer(fd).write(obj)
def writes(obj):
    """Return the serialized form of *obj* as a byte string.

    The value is dumped with :func:`write` into an in-memory buffer whose
    whole content is returned.
    """
    with io.BytesIO() as buffer:
        write(buffer, obj)
        return buffer.getvalue()