"""Provides functions for reading and writing (writing is currently a work in
progress) Java objects in the format produced by ObjectOutputStream and
consumed by ObjectInputStream. This object representation is a standard data
interchange format in the Java world.

The javaobj module exposes an API familiar to users of the standard library
marshal, pickle and json modules.

See: http://download.oracle.com/javase/6/docs/platform/serialization/spec/protocol.html
"""

# Got this 2012-02 from
# http://code.google.com/p/python-javaobj/source/checkout
# License unclear

import StringIO
import struct

# Logging switch: change the condition to True to route log_debug/log_error
# through the standard `logging` module. As shipped, both are no-ops.
if False:
    import logging
    _log = logging.getLogger(__name__)

    def log_debug(message, ident=0):
        """Log *message* at DEBUG level, indented two spaces per *ident* level."""
        _log.debug(" " * (ident * 2) + str(message))

    def log_error(message, ident=0):
        """Log *message* at ERROR level, indented two spaces per *ident* level."""
        _log.error(" " * (ident * 2) + str(message))
else:
    def log_debug(message, ident=0):
        """Discard a debug message (logging disabled)."""

    def log_error(message, ident=0):
        """Discard an error message (logging disabled)."""

__version__ = "$Revision: 20 $"

def load(file_object):
    """Deserialize data written by Java's ObjectOutputStream from a file-like object.

    :param file_object: readable file-like object whose current position is the
        start of the serialization stream header.
    :returns: the first top-level element of the stream, after the default
        collection transformations (see DefaultObjectTransformer).
    """
    unmarshaller = JavaObjectUnmarshaller(file_object)
    unmarshaller.add_transformer(DefaultObjectTransformer())
    return unmarshaller.readObject()


def loads(string):
    """Deserialize data written by Java's ObjectOutputStream from a byte string.

    Equivalent to ``load`` applied to an in-memory buffer over *string*.
    """
    return load(StringIO.StringIO(string))


def dumps(object):
    """Serialize *object* (as previously produced by load/loads) into a string.

    Writing is a work in progress; see JavaObjectMarshaller for what is supported.
    """
    return JavaObjectMarshaller().dump(object)


class JavaClass(object):
    """A class descriptor (classDesc) read from a Java serialization stream.

    Attributes:
        name: class name as read from the stream (None until set).
        serialVersionUID: the class's serialVersionUID (None until set).
        flags: the classDescFlags bit mask (see JavaObjectConstants.SC_*).
        fields_names, fields_types: parallel lists describing the serializable
            fields declared by this class only, not by its superclasses.
        superclass: descriptor of the serializable superclass, or None.
    """

    def __init__(self):
        self.name = None
        self.serialVersionUID = None
        self.flags = None
        self.fields_names = []
        self.fields_types = []
        self.superclass = None

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        """Return ``[name:0xUID]``.

        The UID is rendered as unsigned 64-bit hex, so a negative Java long
        appears in two's complement. ``[name:None]`` is returned while the UID
        is still unset, instead of failing on ``"%X" % None``.
        """
        uid = self.serialVersionUID
        if uid is None:
            return "[%s:None]" % (self.name,)
        return "[%s:0x%X]" % (self.name, uid & 0xFFFFFFFFFFFFFFFF)


class JavaObject(object):
    """A deserialized Java object.

    Serialized field values are stored as instance attributes named after the
    Java fields. ``annotations`` collects objectAnnotation contents read for
    classes that have custom serialization data (SC_WRITE_METHOD, or
    externalizable with SC_BLOCK_DATA).
    """

    def __init__(self):
        self.classdesc = None
        self.annotations = []

    def get_class(self):
        """Return this object's class descriptor, or None if it has none."""
        return self.classdesc

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        name = "UNKNOWN"
        if self.classdesc:
            name = self.classdesc.name
        return "<javaobj:%s>" % name

    def copy(self, new_object):
        """Copy the class descriptor and serialized field values onto *new_object*.

        Fields declared by superclass descriptors are copied as well, because
        the unmarshaller stores the values of the whole class hierarchy on the
        object. When there is no class descriptor only ``classdesc`` (None) is
        copied.
        """
        new_object.classdesc = self.classdesc

        clazz = self.classdesc
        while clazz:
            for name in clazz.fields_names:
                setattr(new_object, name, getattr(self, name))
            clazz = clazz.superclass

class JavaObjectConstants:
    """Constants of the Java Object Serialization Stream Protocol.

    Inherited by JavaObjectUnmarshaller and JavaObjectMarshaller so that every
    value is reachable as ``self.<NAME>``.
    """

    # --- stream header ------------------------------------------------
    STREAM_MAGIC = 0xACED
    STREAM_VERSION = 0x05

    # --- element opcodes (one byte each) ------------------------------
    TC_NULL = 0x70
    TC_REFERENCE = 0x71
    TC_CLASSDESC = 0x72
    TC_OBJECT = 0x73
    TC_STRING = 0x74
    TC_ARRAY = 0x75
    TC_CLASS = 0x76
    TC_BLOCKDATA = 0x77
    TC_ENDBLOCKDATA = 0x78
    TC_RESET = 0x79
    TC_BLOCKDATALONG = 0x7A
    TC_EXCEPTION = 0x7B
    TC_LONGSTRING = 0x7C
    TC_PROXYCLASSDESC = 0x7D
    TC_ENUM = 0x7E
    TC_MAX = TC_ENUM    # highest opcode value

    # --- classDescFlags bits ------------------------------------------
    SC_WRITE_METHOD = 0x01      # only meaningful together with SC_SERIALIZABLE
    SC_SERIALIZABLE = 0x02
    SC_EXTERNALIZABLE = 0x04
    SC_BLOCK_DATA = 0x08        # only meaningful together with SC_EXTERNALIZABLE
    SC_ENUM = 0x10

    # --- field type codes ---------------------------------------------
    # primitives
    TYPE_BYTE = 'B'       # 0x42
    TYPE_CHAR = 'C'       # 0x43
    TYPE_DOUBLE = 'D'     # 0x44
    TYPE_FLOAT = 'F'      # 0x46
    TYPE_INTEGER = 'I'    # 0x49
    TYPE_LONG = 'J'       # 0x4A
    TYPE_SHORT = 'S'      # 0x53
    TYPE_BOOLEAN = 'Z'    # 0x5A
    # references
    TYPE_OBJECT = 'L'     # 0x4C
    TYPE_ARRAY = '['      # 0x5B

    # Every type code accepted in a field descriptor: primitives, then references.
    TYPECODES_LIST = [TYPE_BYTE, TYPE_CHAR, TYPE_DOUBLE, TYPE_FLOAT,
                      TYPE_INTEGER, TYPE_LONG, TYPE_SHORT, TYPE_BOOLEAN,
                      TYPE_OBJECT, TYPE_ARRAY]

    # Handle value of the first referenceable element in a stream.
    BASE_REFERENCE_IDX = 0x7E0000


class JavaObjectUnmarshaller(JavaObjectConstants):
    """Reads objects from a Java serialization stream (ObjectOutputStream format).

    Each element in the stream starts with a one-byte TC_* opcode that is
    dispatched through ``self.opmap`` to a ``do_*`` handler. Elements to which
    the protocol assigns a handle are appended to ``self.references`` in
    stream order, so a TC_REFERENCE handle ``h`` resolves to
    ``self.references[h - BASE_REFERENCE_IDX]``.
    """

    def __init__(self, stream=None):
        """Attach to *stream* (a readable file-like object) and validate its header.

        :raises IOError: if the magic number or version does not match.
        :raises RuntimeError: if the stream is shorter than the header.
        """
        self.opmap = {
            self.TC_NULL: self.do_null,
            self.TC_CLASSDESC: self.do_classdesc,
            self.TC_OBJECT: self.do_object,
            self.TC_STRING: self.do_string,
            self.TC_LONGSTRING: self.do_long_string,
            self.TC_ARRAY: self.do_array,
            self.TC_CLASS: self.do_class,
            self.TC_BLOCKDATA: self.do_blockdata,
            self.TC_BLOCKDATALONG: self.do_blockdata_long,
            self.TC_REFERENCE: self.do_reference,
            self.TC_ENUM: self.do_enum,
            # TC_ENDBLOCKDATA has no payload; do_null is reused and callers
            # detect the terminator by its opcode.
            self.TC_ENDBLOCKDATA: self.do_null,
        }
        self.current_object = None
        self.reference_counter = 0
        self.references = []
        self.object_stream = stream
        self._readStreamHeader()
        self.object_transformers = []

    def readObject(self):
        """Read and return the next top-level element of the stream.

        Bytes left after the element are only reported through the log, and
        the stream position is restored to just after the element. On failure
        the parser state is dumped to the log and the original exception is
        re-raised.
        """
        try:
            opcode, res = self._read_and_exec_opcode(ident=0)    # TODO: add expects

            position_bak = self.object_stream.tell()
            the_rest = self.object_stream.read()
            if len(the_rest):
                log_error("Warning!!!!: Stream still has %s bytes left. Enable debug mode of logging to see the hexdump." % len(the_rest))
                log_debug(self._create_hexdump(the_rest))
            else:
                log_debug("Java Object unmarshalled succesfully!")
            self.object_stream.seek(position_bak)

            return res
        except Exception:
            self._oops_dump_state()
            raise

    def add_transformer(self, transformer):
        """Register *transformer*; its ``transform(obj)`` is applied to each object read."""
        self.object_transformers.append(transformer)

    def _readStreamHeader(self):
        (magic, version) = self._readStruct(">HH")
        if magic != self.STREAM_MAGIC or version != self.STREAM_VERSION:
            raise IOError("The stream is not java serialized object. Invalid stream header: %04X%04X" % (magic, version))

    def _read_and_exec_opcode(self, ident=0, expect=None):
        """Read one opcode byte and run its handler.

        :param expect: optional collection of allowed opcodes.
        :returns: ``(opcode, handler_result)``.
        :raises IOError: if the opcode is not in *expect*.
        :raises RuntimeError: if no handler exists for the opcode.
        """
        (opid, ) = self._readStruct(">B")
        log_debug("OpCode: 0x%X" % opid, ident)
        if expect and opid not in expect:
            raise IOError("Unexpected opcode 0x%X" % opid)
        handler = self.opmap.get(opid)
        if not handler:
            raise RuntimeError("Unknown OpCode in the stream: 0x%x" % opid)
        return (opid, handler(ident=ident))

    def _readBytes(self, length):
        """Read exactly *length* raw bytes; raise RuntimeError on a short read."""
        ba = self.object_stream.read(length)
        if len(ba) != length:
            raise RuntimeError("Stream has been ended unexpectedly while unmarshaling.")
        return ba

    def _readStruct(self, unpack):
        """Read and unpack one ``struct`` record described by format *unpack*."""
        return struct.unpack(unpack, self._readBytes(struct.calcsize(unpack)))

    def _readString(self):
        """Read a short string: unsigned 16-bit length followed by that many bytes."""
        (length, ) = self._readStruct(">H")
        return self._readBytes(length)

    def _read_annotations(self, ident):
        """Read elements up to and including TC_ENDBLOCKDATA.

        :returns: the elements read, excluding the terminator, in stream order.
        """
        annotations = []
        while True:
            opcode, obj = self._read_and_exec_opcode(ident=ident+1)
            if opcode == self.TC_ENDBLOCKDATA:
                return annotations
            log_debug("objectAnnotation value: " + str(obj), ident)
            annotations.append(obj)

    def do_classdesc(self, parent=None, ident=0):
        # TC_CLASSDESC className serialVersionUID newHandle classDescInfo
        #   (newHandle is not stored in the stream: the descriptor simply
        #   takes the next reference index.)
        # classDescInfo:
        #   classDescFlags fields classAnnotation superClassDesc
        # classDescFlags:
        #   (byte)                  // Defined in Terminal Symbols and Constants
        # fields:
        #   (short)<count>  fieldDesc[count]
        # fieldDesc:
        #   primitiveDesc
        #   objectDesc
        # primitiveDesc:
        #   prim_typecode fieldName
        # objectDesc:
        #   obj_typecode fieldName className1
        # classAnnotation:
        #   endBlockData | contents endBlockData
        clazz = JavaClass()
        log_debug("[classdesc]", ident)
        ba = self._readString()
        clazz.name = ba
        log_debug("Class name: %s" % ba, ident)
        # serialVersionUID is a signed 64-bit Java long.
        (serialVersionUID, classDescFlags) = self._readStruct(">qB")
        clazz.serialVersionUID = serialVersionUID
        clazz.flags = classDescFlags

        self._add_reference(clazz)

        log_debug("Serial: 0x%X. classDescFlags: 0x%X" % (serialVersionUID, classDescFlags), ident)
        (length, ) = self._readStruct(">H")
        log_debug("Fields num: 0x%X" % length, ident)

        clazz.fields_names = []
        clazz.fields_types = []
        for fieldId in range(length):
            (typecode, ) = self._readStruct(">B")
            field_name = self._readString()
            field_type = self._convert_char_to_type(typecode)

            if field_type == self.TYPE_ARRAY or field_type == self.TYPE_OBJECT:
                # Reference-typed fields are followed by their type signature
                # (className1), e.g. "Ljava/lang/String;" or "[I".
                opcode, field_type = self._read_and_exec_opcode(ident=ident+1, expect=[self.TC_STRING, self.TC_REFERENCE])
                assert type(field_type) is str

            log_debug("FieldName: 0x%X" % typecode + " " + str(field_name) + " " + str(field_type), ident)
            assert field_name is not None
            assert field_type is not None

            clazz.fields_names.append(field_name)
            clazz.fields_types.append(field_type)
        if parent:
            # NOTE(review): name mangling turns these into
            # parent._JavaObjectUnmarshaller__fields/__types; no handler in
            # this class passes *parent*.
            parent.__fields = clazz.fields_names
            parent.__types = clazz.fields_types
        # classAnnotation (e.g. written by ObjectOutputStream.annotateClass)
        clazz.annotations = self._read_annotations(ident)
        # superClassDesc
        opcode, superclassdesc = self._read_and_exec_opcode(ident=ident+1, expect=[self.TC_CLASSDESC, self.TC_NULL, self.TC_REFERENCE])
        log_debug(str(superclassdesc), ident)
        clazz.superclass = superclassdesc

        return clazz

    def do_blockdata(self, parent=None, ident=0):
        # TC_BLOCKDATA (unsigned byte)<size> (byte)[size]
        log_debug("[blockdata]", ident)
        (length, ) = self._readStruct(">B")
        return self._readBytes(length)

    def do_blockdata_long(self, parent=None, ident=0):
        # TC_BLOCKDATALONG (int)<size> (byte)[size]
        log_debug("[blockdata long]", ident)
        (length, ) = self._readStruct(">I")
        return self._readBytes(length)

    def do_class(self, parent=None, ident=0):
        # TC_CLASS classDesc newHandle
        log_debug("[class]", ident)

        # TODO: what to do with "(ClassDesc)prevObject". (see 3rd line for classDesc:)
        opcode, classdesc = self._read_and_exec_opcode(ident=ident+1, expect=[self.TC_CLASSDESC, self.TC_PROXYCLASSDESC, self.TC_NULL, self.TC_REFERENCE])
        log_debug("Classdesc: %s" % classdesc, ident)
        self._add_reference(classdesc)
        return classdesc

    def _read_class_data(self, java_object, classdesc, ident):
        """Read classdata[] of a Serializable object into *java_object*.

        Data is laid out per class, from the top-most serializable superclass
        down to *classdesc*. Each class contributes its field values, followed
        by its objectAnnotation when it has a writeObject method
        (SC_WRITE_METHOD).
        """
        hierarchy = []
        tempclass = classdesc
        while tempclass:
            hierarchy.append(tempclass)
            tempclass = tempclass.superclass
        hierarchy.reverse()

        for clazz in hierarchy:
            log_debug(">>> " + str(clazz.fields_names) + " " + str(clazz), ident)
            log_debug(">>> " + str(clazz.fields_types), ident)
            for field_name, field_type in zip(clazz.fields_names, clazz.fields_types):
                res = self._read_value(field_type, ident, name=field_name)
                setattr(java_object, field_name, res)
            if clazz.flags & self.SC_WRITE_METHOD:
                java_object.annotations.extend(self._read_annotations(ident))

    def do_object(self, parent=None, ident=0):
        # TC_OBJECT classDesc newHandle classdata[]  // data for each class
        java_object = JavaObject()
        log_debug("[object]", ident)

        # TODO: what to do with "(ClassDesc)prevObject". (see 3rd line for classDesc:)
        opcode, classdesc = self._read_and_exec_opcode(ident=ident+1, expect=[self.TC_CLASSDESC, self.TC_PROXYCLASSDESC, self.TC_NULL, self.TC_REFERENCE])
        # self.TC_REFERENCE hasn't shown in spec, but actually is here

        # Remember this object's handle slot so it can be replaced by the
        # transformed object below.
        handle_index = len(self.references)
        self._add_reference(java_object)

        java_object.classdesc = classdesc

        if classdesc.flags & self.SC_EXTERNALIZABLE:
            if not classdesc.flags & self.SC_BLOCK_DATA:
                raise NotImplementedError("externalContents isn't implemented yet") # TODO:
            # writeExternal data in block-data mode, terminated by TC_ENDBLOCKDATA
            java_object.annotations.extend(self._read_annotations(ident))
        elif classdesc.flags & self.SC_SERIALIZABLE:
            self._read_class_data(java_object, classdesc, ident)
        log_debug("java_object.annotations: " + str(java_object.annotations), ident)

        # Transform object: the first transformer returning a different object wins.
        for transformer in self.object_transformers:
            tmp_object = transformer.transform(java_object)
            if tmp_object is not java_object:
                java_object = tmp_object
                break

        # Later back-references must resolve to the object actually returned.
        self.references[handle_index] = java_object
        log_debug(">>> java_object: " + str(java_object), ident)
        return java_object

    def do_string(self, parent=None, ident=0):
        # TC_STRING newHandle (utf)
        log_debug("[string]", ident)
        ba = self._readString()
        self._add_reference(str(ba))
        return str(ba)

    def do_long_string(self, parent=None, ident=0):
        # TC_LONGSTRING newHandle (long)<length> (byte)[length]
        log_debug("[long string]", ident)
        (length, ) = self._readStruct(">Q")
        ba = self._readBytes(length)
        self._add_reference(str(ba))
        return str(ba)

    def do_array(self, parent=None, ident=0):
        # TC_ARRAY classDesc newHandle (int)<size> values[size]
        log_debug("[array]", ident)
        opcode, classdesc = self._read_and_exec_opcode(ident=ident+1, expect=[self.TC_CLASSDESC, self.TC_PROXYCLASSDESC, self.TC_NULL, self.TC_REFERENCE])

        array = []

        self._add_reference(array)

        (size, ) = self._readStruct(">i")
        log_debug("size: " + str(size), ident)

        # Array class names look like "[I" or "[Ljava.lang.String;": the
        # character after '[' is the element type code.
        type_char = classdesc.name[0]
        assert type_char == self.TYPE_ARRAY
        type_char = classdesc.name[1]

        if type_char == self.TYPE_OBJECT or type_char == self.TYPE_ARRAY:
            for i in range(size):
                opcode, res = self._read_and_exec_opcode(ident=ident+1)
                log_debug("Object value: %s" % str(res), ident)
                array.append(res)
        else:
            for i in range(size):
                res = self._read_value(type_char, ident)
                log_debug("Native value: %s" % str(res), ident)
                array.append(res)

        return array

    def do_reference(self, parent=None, ident=0):
        # TC_REFERENCE (int)handle
        (handle, ) = self._readStruct(">L")
        log_debug("## Reference handle: 0x%x" % (handle), ident)
        index = handle - self.BASE_REFERENCE_IDX
        # Reject out-of-range handles instead of letting a negative index
        # silently pick an unrelated element from the end of the list.
        if not 0 <= index < len(self.references):
            raise IOError("Invalid reference handle: 0x%X" % handle)
        return self.references[index]

    def do_null(self, parent=None, ident=0):
        return None

    def do_enum(self, parent=None, ident=0):
        # TC_ENUM classDesc newHandle enumConstantName
        opcode, classdesc = self._read_and_exec_opcode(ident=ident+1, expect=[self.TC_CLASSDESC, self.TC_PROXYCLASSDESC, self.TC_NULL, self.TC_REFERENCE])
        # The enum's handle precedes the constant-name string (which gets its
        # own handle), so reserve the slot now and fill it in afterwards.
        handle_index = len(self.references)
        self._add_reference(None)
        opcode, enumConstantName = self._read_and_exec_opcode(ident=ident+1, expect=[self.TC_STRING, self.TC_REFERENCE])
        # Back-references to this enum resolve to the same value returned here.
        self.references[handle_index] = enumConstantName
        return enumConstantName

    def _create_hexdump(self, src, length=16):
        """Format *src* as hexdump lines of *length* bytes (offset, hex, printable)."""
        FILTER=''.join([(len(repr(chr(x)))==3) and chr(x) or '.' for x in range(256)])
        result = []
        for i in xrange(0, len(src), length):
            s = src[i:i+length]
            hexa = ' '.join(["%02X"%ord(x) for x in s])
            printable = s.translate(FILTER)
            result.append("%04X   %-*s  %s\n" % (i, length*3, hexa, printable))
        return ''.join(result)

    def _read_value(self, field_type, ident, name = ""):
        """Read one value whose type is given by *field_type*.

        Only the first character of *field_type* matters, so full signatures
        such as "Ljava/lang/String;" are accepted.

        :raises RuntimeError: for an unknown type code.
        """
        if len(field_type) > 1:
            field_type = field_type[0]  # We don't need details for arrays and objects

        if field_type == self.TYPE_BOOLEAN:
            (val, ) = self._readStruct(">B")
            res = bool(val)
        elif field_type == self.TYPE_BYTE:
            (res, ) = self._readStruct(">b")
        elif field_type == self.TYPE_CHAR:
            # A Java char is one UTF-16 code unit.
            (code_unit, ) = self._readStruct(">H")
            res = unichr(code_unit)
        elif field_type == self.TYPE_SHORT:
            (res, ) = self._readStruct(">h")
        elif field_type == self.TYPE_INTEGER:
            (res, ) = self._readStruct(">i")
        elif field_type == self.TYPE_LONG:
            (res, ) = self._readStruct(">q")
        elif field_type == self.TYPE_FLOAT:
            (res, ) = self._readStruct(">f")
        elif field_type == self.TYPE_DOUBLE:
            (res, ) = self._readStruct(">d")
        elif field_type == self.TYPE_OBJECT or field_type == self.TYPE_ARRAY:
            opcode, res = self._read_and_exec_opcode(ident=ident+1)
        else:
            raise RuntimeError("Unknown typecode: %s" % field_type)
        log_debug("* %s %s: " % (field_type, name) + str(res), ident)
        return res

    def _convert_char_to_type(self, type_char):
        """Map a type-code byte (int or 1-char str) to a supported TYPE_* code."""
        typecode = type_char
        if type(type_char) is int:
            typecode = chr(type_char)

        if typecode in self.TYPECODES_LIST:
            return typecode
        else:
            raise RuntimeError("Typecode %s (%s) isn't supported." % (type_char, typecode))

    def _add_reference(self, obj):
        """Assign the next handle to *obj*."""
        self.references.append(obj)

    def _oops_dump_state(self):
        """Log the reference table and a hexdump of the stream starting 16 bytes
        before the current position. Best effort: this method never raises."""
        log_error("==Oops state dump" + "=" * (30 - 17))
        try:
            log_error("References: %s" % str(self.references))
            log_error("Stream seeking back at -16 byte (2nd line is an actual position!):")
            # Absolute, clamped seek: real file objects do not accept
            # StringIO's ``mode=`` keyword and cannot seek before offset 0.
            position = self.object_stream.tell()
            self.object_stream.seek(max(0, position - 16))
            the_rest = self.object_stream.read()
            if len(the_rest):
                log_error("Warning!!!!: Stream still has %s bytes left." % len(the_rest))
                log_error(self._create_hexdump(the_rest))
        except Exception:
            # Purely diagnostic: must not replace the error readObject re-raises.
            log_error("Unable to dump the stream state.")
        log_error("=" * 30)


class JavaObjectMarshaller(JavaObjectConstants):
    """Writes objects in the Java serialization stream format.

    Work in progress: only ``str`` (as block data) is written completely;
    JavaObject output is incomplete (see write_object).
    """

    def __init__(self, stream=None):
        self.object_stream = stream

    def dump(self, obj):
        """Serialize *obj* and return the stream contents as a string.

        A fresh in-memory buffer replaces any stream given to __init__.
        """
        self.object_obj = obj
        self.object_stream = StringIO.StringIO()
        self._writeStreamHeader()
        self.writeObject(obj)
        return self.object_stream.getvalue()

    def _writeStreamHeader(self):
        self._writeStruct(">HH", 4, (self.STREAM_MAGIC, self.STREAM_VERSION))

    def writeObject(self, obj):
        """Write *obj*: a JavaObject or a str (written as block data).

        :raises RuntimeError: for any other type.
        """
        log_debug("Writing object of type " + str(type(obj)))
        if type(obj) is JavaObject:
            self.write_object(obj)
        elif type(obj) is str:
            self.write_blockdata(obj)
        else:
            raise RuntimeError("Object serialization of type %s is not supported." % str(type(obj)))

    def _writeStruct(self, unpack, length, args):
        """Pack *args* with struct format *unpack* and write the result.

        *length* is unused; the size follows from the format.
        """
        ba = struct.pack(unpack, *args)
        self.object_stream.write(ba)

    def _writeString(self, string):
        """Write *string* as an unsigned 16-bit length followed by its bytes."""
        # A local named `len` would shadow the builtin for the whole function
        # and make `len(string)` raise UnboundLocalError.
        length = len(string)
        self._writeStruct(">H", 2, (length, ))
        self.object_stream.write(string)

    def write_blockdata(self, obj, parent=None):
        """Write *obj* (a str) as block data.

        TC_BLOCKDATA (unsigned byte)<size> (byte)[size] for up to 255 bytes,
        TC_BLOCKDATALONG (int)<size> (byte)[size] for longer data.

        :raises TypeError: if *obj* is not a str (writing only the opcode
            would leave a corrupt stream).
        """
        if type(obj) is not str:
            raise TypeError("Block data must be a str, got %s" % str(type(obj)))
        if len(obj) <= 0xFF:
            self._writeStruct(">B", 1, (self.TC_BLOCKDATA, ))
            self._writeStruct(">B", 1, (len(obj), ))
        else:
            self._writeStruct(">B", 1, (self.TC_BLOCKDATALONG, ))
            self._writeStruct(">i", 4, (len(obj), ))
        self.object_stream.write(obj)

    def write_object(self, obj, parent=None):
        # NOTE(review): incomplete. Only the TC_OBJECT and TC_CLASSDESC opcodes
        # are emitted; the class descriptor body and class data are not, so the
        # resulting stream is not readable by ObjectInputStream yet.
        self._writeStruct(">B", 1, (self.TC_OBJECT, ))
        self._writeStruct(">B", 1, (self.TC_CLASSDESC, ))

class DefaultObjectTransformer(object):
    """Converts well-known Java collections into Python containers.

    java.util.ArrayList / java.util.LinkedList become JavaList (a list) and
    java.util.HashMap becomes JavaMap (a dict). The serialized fields are
    copied over as attributes. Any other object is returned unchanged.
    """

    class JavaList(list, JavaObject):
        pass

    class JavaMap(dict, JavaObject):
        pass

    # List classes whose writeObject emits a leading header entry followed by
    # the elements, in order, as objectAnnotation entries.
    _LIST_CLASSES = ("java.util.ArrayList", "java.util.LinkedList")

    def transform(self, object):
        """Return a Python container for *object*, or *object* itself if not handled."""
        classdesc = object.get_class()
        if classdesc is None:
            return object
        class_name = classdesc.name

        if class_name in self._LIST_CLASSES:
            #    * @serialData The length of the array backing the <tt>ArrayList</tt>
            #    *             instance is emitted (int), followed by all of its elements
            #    *             (each an <tt>Object</tt>) in the proper order.
            # annotations[0] is that size header; the elements follow it.
            new_object = self.JavaList()
            object.copy(new_object)
            new_object.extend(object.annotations[1:])
            return new_object

        if class_name == "java.util.HashMap":
            # annotations[0] is skipped (presumably the capacity/size block
            # data); keys and values then alternate.
            new_object = self.JavaMap()
            object.copy(new_object)
            entries = object.annotations[1:]
            for key, value in zip(entries[0::2], entries[1::2]):
                new_object[key] = value
            return new_object

        return object
