1 # Copyright (C) 2012 Nippon Telegraph and Telephone Corporation.
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
7 # http://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
import base64
import inspect
import struct

import six

from . import packet_base
from . import ethernet

from ryu import utils
from ryu.lib.stringify import StringifyMixin
# Packet class dictionary.
#
# Maps a name to a PacketBase subclass for every PacketBase subclass exposed
# as an attribute of any module attribute of the ryu.lib.packet package
# (i.e. its already-imported submodules).  PacketBase itself also matches
# the predicate and is included.  Packet.from_jsondict uses this table to
# turn a protocol name back into a class.
mod = inspect.getmembers(utils.import_module("ryu.lib.packet"),
                         inspect.ismodule)
cls_list = []
for _, m in mod:
    cl = inspect.getmembers(
        m,
        lambda cls: (inspect.isclass(cls) and
                     issubclass(cls, packet_base.PacketBase)))
    # inspect.getmembers() already returns a new list; no extra copy needed.
    cls_list.extend(cl)
# If the same name is exposed by several modules, the last (name, class)
# pair collected wins, as with any dict built from a list of pairs.
PKT_CLS_DICT = dict(cls_list)
class Packet(StringifyMixin):
    """A packet decoder/encoder class.

    An instance is used to either decode or encode a single packet.

    *data* is a bytearray describing a raw datagram to decode.
    When decoding, a Packet object is iterable.
    Iterated values are protocol (ethernet, ipv4, ...) headers and the payload.
    Protocol headers are instances of subclasses of packet_base.PacketBase.
    The payload is a bytearray. They are iterated in on-wire order.

    *data* should be omitted when encoding a packet.
    """

    # Ignore data field when outputting json representation.
    _base_attributes = ['data']

    def __init__(self, data=None, protocols=None, parse_cls=ethernet.ethernet):
        """Create a packet.

        :param data: raw datagram to decode; leave as None when encoding.
        :param protocols: initial list of protocol headers/payload (used
            as-is, not copied); a new empty list is created when None.
        :param parse_cls: PacketBase subclass used to parse the outermost
            header of *data*.
        """
        super(Packet, self).__init__()
        self.data = data
        # Create a new list per instance so packets never share one list.
        self.protocols = [] if protocols is None else protocols
        if self.data:
            # Decoding mode: split the raw datagram into protocol headers.
            self._parser(parse_cls)

    @staticmethod
    def _is_padding(buf):
        """Return True if *buf* is falsy (None/empty) or only zero bytes."""
        return not buf or not six.binary_type(buf).strip(b'\x00')

    def _parser(self, cls):
        """Decode self.data into self.protocols, starting with parser *cls*.

        Each ``cls.parser(buf)`` call returns ``(proto, next_cls, rest)``.
        Parsing continues with ``next_cls`` on ``rest`` until there is no
        next class, the remaining buffer is empty or zero padding, or the
        parser raises ``struct.error`` (typically: the buffer is too short
        for that header).  Any leftover bytes that are not pure padding are
        appended as the raw payload.
        """
        rest_data = self.data
        while cls:
            # Ignores an empty (or zero-padding only) buffer.
            if self._is_padding(rest_data):
                break
            try:
                proto, cls, rest_data = cls.parser(rest_data)
            except struct.error:
                # This layer could not be decoded; the undecoded bytes are
                # kept below as the payload.
                break
            if proto:
                self.protocols.append(proto)
        # If rest_data is all padding, we ignore rest_data.
        if not self._is_padding(rest_data):
            self.protocols.append(rest_data)

    def serialize(self):
        """Encode a packet and store the resulting bytearray in self.data.

        This method is legal only when encoding a packet.

        Protocols are encoded from the last (innermost) to the first
        (outermost).  Each header's ``serialize(payload, prev)`` therefore
        receives the already-encoded bytes that follow it and the protocol
        preceding it on the wire (None for the first protocol).  Entries
        that are not PacketBase instances (raw payload) are converted with
        ``six.binary_type``.
        """
        self.data = bytearray()
        r = self.protocols[::-1]
        for i, p in enumerate(r):
            if isinstance(p, packet_base.PacketBase):
                # r is reversed, so r[i + 1] is the protocol just before p
                # on the wire.
                prev = r[i + 1] if i + 1 < len(r) else None
                data = p.serialize(self.data, prev)
            else:
                data = six.binary_type(p)
            self.data = bytearray(data + self.data)

    @classmethod
    def from_jsondict(cls, dict_, decode_string=base64.b64decode,
                      **additional_args):
        """Build a Packet from its JSON-dict representation.

        ``dict_['protocols']`` is a list of dicts, each mapping a protocol
        class name (a key of PKT_CLS_DICT) to that protocol's JSON dict.
        *decode_string* and *additional_args* are accepted but not used
        here (presumably kept for signature compatibility with
        StringifyMixin.from_jsondict — confirm).

        :raises ValueError: if a protocol name is not in PKT_CLS_DICT.
        """
        protocols = []
        for proto in dict_['protocols']:
            for key, value in proto.items():
                if key not in PKT_CLS_DICT:
                    raise ValueError('unknown protocol name %s' % key)
                protocols.append(PKT_CLS_DICT[key].from_jsondict(value))

        return cls(protocols=protocols)

    def add_protocol(self, proto):
        """Register a protocol *proto* for this packet.

        This method is legal only when encoding a packet.

        When encoding a packet, register a protocol (ethernet, ipv4, ...)
        header to add to this packet.
        Protocol headers should be registered in on-wire order before calling
        self.serialize.
        """
        self.protocols.append(proto)

    def get_protocols(self, protocol):
        """Return a list of the protocols that match *protocol*.

        *protocol* may be a PacketBase subclass or an instance of one (its
        class is then used).  Matching uses isinstance(), so instances of
        subclasses match too.
        """
        if isinstance(protocol, packet_base.PacketBase):
            protocol = protocol.__class__
        # NOTE: this check disappears under ``python -O``.
        assert issubclass(protocol, packet_base.PacketBase)
        return [p for p in self.protocols if isinstance(p, protocol)]

    def get_protocol(self, protocol):
        """Return the first protocol that matches *protocol*, or None."""
        result = self.get_protocols(protocol)
        return result[0] if result else None

    def __div__(self, trailer):
        """``pkt / proto``: append *trailer* and return this same packet."""
        self.add_protocol(trailer)
        return self

    def __truediv__(self, trailer):
        return self.__div__(trailer)

    def __iter__(self):
        return iter(self.protocols)

    def __getitem__(self, idx):
        return self.protocols[idx]

    def __setitem__(self, idx, item):
        self.protocols[idx] = item

    def __delitem__(self, idx):
        del self.protocols[idx]

    def __len__(self):
        return len(self.protocols)

    def __contains__(self, protocol):
        """Support ``ProtoClass in pkt`` (exact class) and ``obj in pkt``."""
        if (inspect.isclass(protocol) and
                issubclass(protocol, packet_base.PacketBase)):
            # Short-circuits instead of building a throwaway list.
            return protocol in (p.__class__ for p in self.protocols)
        return protocol in self.protocols

    def __str__(self):
        return ', '.join(repr(protocol) for protocol in self.protocols)
    __repr__ = __str__  # note: str(list) uses __repr__ for elements
# XXX: Hack for preventing recursive import
def _PacketBase__div__(self, trailer):
    """Stack two protocols into a brand-new Packet: ``proto / trailer``.

    Installed below as PacketBase.__div__ and PacketBase.__truediv__.  It is
    defined in this module rather than in packet_base because it needs
    Packet, and referencing Packet from packet_base would create a recursive
    import.
    """
    stacked = Packet()
    for layer in (self, trailer):
        stacked.add_protocol(layer)
    return stacked


# __div__ serves Python 2's classic "/" operator; __truediv__ serves true
# division (Python 3).  Both map to the same implementation.
packet_base.PacketBase.__div__ = _PacketBase__div__
packet_base.PacketBase.__truediv__ = _PacketBase__div__