2
0
mirror of https://github.com/mozilla/cipherscan.git synced 2024-11-05 07:23:42 +01:00
cipherscan/cscan/extensions.py

199 lines
6.5 KiB
Python
Raw Normal View History

# Copyright 2016(c) Hubert Kario
# This work is released under the Mozilla Public License Version 2.0
"""Extra TLS extensions."""
import tlslite.extensions
from tlslite.utils.codec import Writer
from tlslite.utils.compat import b2a_hex
from .constants import ExtensionType, GroupName
# "import .messages" is not valid Python: a plain ``import`` statement cannot
# take a relative module name.  The explicit relative form below binds the
# same ``messages`` name that the ``__format__`` methods in this file use.
from . import messages

# Make TLSExtensions hashable (__eq__ is already defined in base class; on
# Python 3 a class that defines __eq__ without __hash__ is unhashable).
# The hash combines the extension type with the serialised payload.
tlslite.extensions.TLSExtension.__hash__ = lambda self: \
    hash(self.extType) ^ hash(bytes(self.extData))
class RenegotiationExtension(tlslite.extensions.TLSExtension):
    """
    Secure Renegotiation extension RFC 5746.

    The payload is kept in C{renegotiated_connection}.  A value of C{None}
    stands for an extension whose body is completely empty (zero bytes).
    """

    def __init__(self):
        """Initialize secure renegotiation extension."""
        super(RenegotiationExtension, self).__init__(
            extType=ExtensionType.renegotiation_info)
        self.renegotiated_connection = None

    def create(self, data):
        """
        Set the value of the Finished message.

        @param data: value to store as C{renegotiated_connection}; C{None}
        makes the extension serialise to an empty body
        @return: self, so the call can be chained (same convention as
        C{KeyShareExtension.create} in this module)
        """
        self.renegotiated_connection = data
        return self

    @property
    def extData(self):
        """
        Serialise the extension.

        @return: empty bytearray when no value is set, otherwise the value
        written through C{Writer.addVarSeq} with 1-byte elements and a 1-byte
        length prefix
        """
        if self.renegotiated_connection is None:
            return bytearray(0)

        writer = Writer()
        writer.addVarSeq(self.renegotiated_connection, 1, 1)
        return writer.bytes

    def parse(self, parser):
        """
        Deserialise the extension from binary data.

        @param parser: parser positioned at the start of the extension body
        @return: self
        """
        if parser.getRemainingLength() == 0:
            self.renegotiated_connection = None
            # Return the instance on this path too: the non-empty path below
            # returns self, and a bare ``return`` would hand None to any
            # caller that uses the result of parse().
            return self

        # body is a single byte string with a 1-byte length prefix
        self.renegotiated_connection = parser.getVarBytes(1)
        return self

    def __repr__(self):
        """Human readable representation of extension."""
        return "RenegotiationExtension(renegotiated_connection={0!r})"\
            .format(self.renegotiated_connection)

    def __format__(self, formatstr):
        """
        Formatted representation of extension.

        @param formatstr: format specification, passed unchanged to
        C{messages.format_bytearray} to render the payload
        """
        data = messages.format_bytearray(self.renegotiated_connection,
                                         formatstr)
        return "RenegotiationExtension(renegotiated_connection={0})"\
            .format(data)


# register the parser for this extension type
tlslite.extensions.TLSExtension._universalExtensions[
    ExtensionType.renegotiation_info] = RenegotiationExtension
class SessionTicketExtension(tlslite.extensions.TLSExtension):
    """
    Session Ticket extension.

    The raw extension body is stored, uninterpreted, in C{data}.
    """
    # NOTE(review): this class does not override extData, so whether C{data}
    # is ever serialised depends on the base class - confirm.

    def __init__(self):
        """Create Session Ticket extension with an empty payload."""
        super(SessionTicketExtension, self).__init__(
            extType=ExtensionType.session_ticket)
        self.data = bytearray()

    def parse(self, parser):
        """
        Deserialise the extension from binary data.

        @param parser: parser whose C{bytes} buffer is taken as the payload
        @return: self
        """
        self.data = parser.bytes
        return self

    def __format__(self, formatstr):
        """
        Print extension data in human-readable form.

        @param formatstr: format specification, passed unchanged to
        C{messages.format_bytearray}
        """
        return "SessionTicketExtension(data={0})".format(
            messages.format_bytearray(self.data, formatstr))


# register the parser for this extension type
tlslite.extensions.TLSExtension._universalExtensions[
    ExtensionType.session_ticket
] = SessionTicketExtension
class ServerStatusRequestExtension(tlslite.extensions.TLSExtension):
    """
    Server Status Request extension.

    Carries no state of its own: only an empty extension body is accepted
    when parsing.
    """

    def __init__(self):
        """Create server status request extension."""
        super(ServerStatusRequestExtension, self).__init__(
            extType=ExtensionType.status_request)

    def parse(self, parser):
        """
        Deserialise the extension from binary data.

        @param parser: parser positioned at the start of the extension body
        @return: self
        @raise SyntaxError: when the extension body is not empty
        """
        if parser.getRemainingLength() == 0:
            return self
        # any payload at all is treated as malformed
        raise SyntaxError()  # FIXME

    def __repr__(self):
        """Human readable representation of the object."""
        return "ServerStatusRequestExtension()"


# registered in the server-side table only
tlslite.extensions.TLSExtension._serverExtensions[
    ExtensionType.status_request
] = ServerStatusRequestExtension
class KeyShareExtension(tlslite.extensions.TLSExtension):
    """
    TLS1.3 extension for handling key negotiation.

    C{client_shares} is either C{None} (extension with an empty body) or a
    list of C{(group_id, share)} tuples.
    """

    def __init__(self):
        """Create key share extension object."""
        super(KeyShareExtension, self).__init__(
            extType=ExtensionType.key_share)
        self.client_shares = None

    @staticmethod
    def _share_length_size(group_id):
        """Return the size in bytes of the length prefix of a key share."""
        # groups listed in GroupName.allFF use a 2-byte length prefix,
        # every other group a 1-byte one
        if group_id in GroupName.allFF:
            return 2
        return 1

    def create(self, shares):
        """
        Set the list of key shares to send.

        @type shares: list of tuples
        @param shares: a list of tuples where the first element is a NamedGroup
        ID while the second element in a tuple is an opaque bytearray encoding
        of the key share.
        @return: self
        """
        self.client_shares = shares
        return self

    @property
    def extData(self):
        """
        Serialise the extension.

        @return: empty bytearray when C{client_shares} is None, otherwise
        the encoded shares preceded by their total length as a 2-byte field
        """
        if self.client_shares is None:
            return bytearray(0)

        writer = Writer()
        for group_id, share in self.client_shares:
            writer.add(group_id, 2)
            writer.addVarSeq(share, 1, self._share_length_size(group_id))

        # wrap the list of shares in its 2-byte overall length
        ext_writer = Writer()
        ext_writer.add(len(writer.bytes), 2)
        ext_writer.bytes += writer.bytes
        return ext_writer.bytes

    def parse(self, parser):
        """
        Deserialise the extension.

        @param parser: parser positioned at the start of the extension body
        @return: self
        """
        if parser.getRemainingLength() == 0:
            self.client_shares = None
            # Return the instance on this path too: the non-empty path below
            # returns self, and a bare ``return`` would hand None to any
            # caller that uses the result of parse().
            return self

        self.client_shares = []

        parser.startLengthCheck(2)
        while not parser.atLengthCheck():
            group_id = parser.get(2)
            share = parser.getVarBytes(self._share_length_size(group_id))
            self.client_shares.append((group_id, share))

        return self

    def __repr__(self):
        """Human readable representation of extension."""
        return "KeyShareExtension({0!r})".format(self.client_shares)

    def __format__(self, formatstr):
        """
        Formattable representation of extension.

        @param formatstr: format specification; a C{v} in it prefixes group
        names with C{GroupName.}, an C{h} prints the shares hex-encoded
        instead of as their C{repr()}
        """
        if self.client_shares is None:
            return "KeyShareExtension(None)"

        verbose = "GroupName." if 'v' in formatstr else ""
        hexlify = 'h' in formatstr

        shares = []
        for group_id, share in self.client_shares:
            share_text = b2a_hex(share) if hexlify else repr(share)
            shares.append("({0}{1}, {2})".format(verbose,
                                                 GroupName.toStr(group_id),
                                                 share_text))
        return "KeyShareExtension([" + ",".join(shares) + "])"


# register the parser for this extension type
tlslite.extensions.TLSExtension._universalExtensions[
    ExtensionType.key_share] = KeyShareExtension