# UNI Protocol
The name comes from the protocol being mostly UNIdirectional; "uni" is also Japanese for sea urchin, and they are awesome. Please suggest another name if you have objections!
# utils.py
Helper functions for reading, splitting and writing files (imported as `utils` by uni_protocol.py).
```python
def split_file_into_superframes(filepath, frame_size):
    """Read *filepath* and split its contents into chunks of *frame_size* bytes.

    Every chunk is *frame_size* bytes long except possibly the last one.
    An empty file yields an empty list.

    Returns:
        A list of ``bytes`` chunks, or ``None`` if the file cannot be opened
        or read.

    Raises:
        ValueError: if *frame_size* is not positive (``read(0)`` would return
            ``b""`` immediately and silently produce no chunks).
    """
    if frame_size <= 0:
        raise ValueError("frame_size must be positive, got %r" % (frame_size,))
    data_fragments = []
    try:
        # `with` closes the file on every path, including when read() fails,
        # and never touches a file object that open() failed to create.
        with open(filepath, "rb") as fd:
            while True:
                chunk = fd.read(frame_size)
                if not chunk:
                    break
                data_fragments.append(chunk)
    except OSError:
        return None
    return data_fragments
def fold_data(l, length):
    """Convert a sequence of byte values (ints in 0..255) into ``bytes``.

    *length* is accepted for call-site compatibility but is not used: the
    whole of *l* is converted.
    """
    folded = bytes(l)
    return folded
def write_file(filename, data):
    """Write the bytes-like *data* to *filename*, replacing any existing content."""
    with open(filename, "wb") as out_file:
        out_file.write(data)
```
# uni_defaults.py
Default protocol parameters (erasure coding, framing and interleaving).
```python=
EC_K_VALUE = 20
EC_M_VALUE = 10
EC_TYPE = "liberasurecode_rs_vand"
SUPERFRAME_SIZE = 25600
# find a way to calculate this automatically!
# if K = 20, M = 10, then there will be 120 packets with 1360 bytes each!
# changing K will change the payload size! changing M will change the number of pkts that can be left out!
#UNI_TOTAL_PKTS = 120 (if 20,10, interleave matrix is 12r 10c)
#INTERLEAVE_ROWSIZE = 10
#if 20,6, interleave matrix is 13r, 8c)
INTERLEAVE_ROWSIZE = 10
UNI_TOTAL_PKTS = (EC_K_VALUE + EC_M_VALUE) * 4
# find a way to calculate this automatically!
UNI_PAYLOAD_SIZE = 1360
UNI_HEADER_SIZE = 4
```
# uni_protocol.py
Main loops handling sending, receiving, FEC encoding and FEC decoding.
```python
from queue import Queue, Empty
import socket
import threading
from datetime import datetime
from enum import Enum
from pyeclib.ec_iface import ECDriver
import time
import utils
from uni_packet import UNIPacketType, UNIHeader, UNIPacket
import uni_defaults
class OpMode(Enum):
    """Role a UNI endpoint was bound with."""

    SendMode = 0  # set by UNI.bind_as_sender
    RecvMode = 1  # set by UNI.bind_as_receiver
class UNI:
    """One endpoint of the UNI file-transfer protocol over UDP.

    Call ``bind_as_sender`` or ``bind_as_receiver`` before use.

    * Sender: ``send`` splits a file into superframes, FEC-encodes each one
      into EC_K_VALUE + EC_M_VALUE shards, transmits one shard per datagram in
      interleaved order and keeps retransmitting until the receiver's FIN
      arrives or a short timeout expires.
    * Receiver: a background thread collects shards, tries to decode each
      superframe once it holds EC_K_VALUE of them, and when every superframe
      of a file is decoded hands the file to ``recv`` and replies with FIN.
    """

    # Seconds to keep retransmitting after the first full round while no FIN arrives.
    _FIN_TIMEOUT = 0.5

    def __init__(self, mtu=1500):
        # mtu is stored for callers; nothing in this class reads it.
        self.mtu = mtu
        # TODO: accept k, m and the EC type as constructor parameters.
        self.ec_driver = ECDriver(
            k=uni_defaults.EC_K_VALUE,
            m=uni_defaults.EC_M_VALUE,
            ec_type=uni_defaults.EC_TYPE,
        )
        self.last_update = datetime.timestamp(datetime.now())
        self.started = False  # set once the receiver gets its first datagram
        self.packet_cache = {}

    def bind_as_sender(self, receiver_address):
        """Create the UDP socket in sender role and start the FIN-listener thread.

        Args:
            receiver_address: address passed to ``socket.sendto`` for every
                data packet (an ``(host, port)`` tuple for AF_INET).
        """
        self.mode = OpMode.SendMode
        # file_id -> Queue that receives (fin, file_id) tuples from the listener
        self.connection_manager = {}
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.receiver_address = receiver_address
        self.lock = threading.Lock()
        listener = threading.Thread(target=self._sender_packet_loop, daemon=True)
        listener.start()

    def bind_as_receiver(self, receiver_address):
        """Bind a UDP socket to *receiver_address* and start the receive thread."""
        self.mode = OpMode.RecvMode
        self.received_files_data = {}
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.socket.bind(receiver_address)
        self.file_received = Queue()  # (file_id, length) of every completed file
        receiver = threading.Thread(target=self._receiver_packet_loop, daemon=True)
        receiver.start()

    def drop(self):
        """Forget pending sends (sender role) and close the socket.

        The background thread notices the closed socket and exits.
        """
        if self.mode == OpMode.SendMode:
            self.connection_manager.clear()
        self.socket.close()

    # ------------------------------------------------------------------ helpers

    def _socket_closed(self):
        """Return True once ``self.socket`` has been closed (e.g. by ``drop``)."""
        return self.socket.fileno() == -1

    @staticmethod
    def _print_exception():
        """Print the traceback of the exception currently being handled."""
        import traceback
        traceback.print_exc()

    @staticmethod
    def _blocks_per_file():
        """Superframes per file implied by the defaults (4 with the shipped values)."""
        return uni_defaults.UNI_TOTAL_PKTS // (uni_defaults.EC_K_VALUE + uni_defaults.EC_M_VALUE)

    # =================================================================== sender

    def _sender_packet_loop(self):
        """Listener thread (sender role): route FIN replies to the waiting ``send``.

        Exits when the socket has been closed.
        """
        if self.mode == OpMode.RecvMode:
            raise RuntimeError("_sender_packet_loop needs an endpoint bound as sender")
        while True:
            try:
                packet = UNIPacket()
                packet.from_raw(self.socket.recv(2048))
                file_id = packet.header.file_id
                # Single lookup: send() and drop() remove entries from other threads.
                pending = self.connection_manager.get(file_id)
                if pending is not None and packet.header.packet_type == UNIPacketType.Fin.value:
                    pending.put((True, file_id))
            except Exception:
                if self._socket_closed():
                    return  # stop instead of spinning on errors forever
                self._print_exception()

    def get_next_sequence_new(self, current):
        """Return the index of the packet to send after packet *current*.

        The UNI_TOTAL_PKTS packets are laid out row by row in a matrix with
        INTERLEAVE_ROWSIZE columns and transmitted column by column, so the
        shards of one superframe are spread apart on the wire. Adding
        INTERLEAVE_ROWSIZE moves one row down; running off the bottom wraps to
        the top of the next column. Every index is visited exactly once per
        round only when UNI_TOTAL_PKTS is a multiple of INTERLEAVE_ROWSIZE;
        after the last packet of a round this returns INTERLEAVE_ROWSIZE, not 0.
        """
        step = current + uni_defaults.INTERLEAVE_ROWSIZE
        following = step % uni_defaults.UNI_TOTAL_PKTS
        if step >= uni_defaults.UNI_TOTAL_PKTS:
            following += 1  # past the bottom of this column: move to the next one
        return following

    def send(self, filepath, file_id):  # will block the thread
        """Send the file at *filepath* tagged with *file_id*; blocks until done.

        One full interleaved round of packets is always sent. After that the
        packets keep being retransmitted (in the same order) until the
        receiver's FIN for *file_id* arrives or ``_FIN_TIMEOUT`` seconds pass.

        Raises:
            RuntimeError: if this endpoint is bound as a receiver.
            OSError: if ``utils.split_file_into_superframes`` returns None
                (the file could not be read).
            ValueError: if the file does not split into exactly the number of
                superframes the protocol addresses.
        """
        if self.mode == OpMode.RecvMode:
            raise RuntimeError("send() needs an endpoint bound as sender")
        # Serialize once: packets are resent many times while waiting for FIN,
        # and a malformed header fails here instead of mid-transfer.
        datagrams = [packet.raw() for packet in self._build_packets(filepath, file_id)]
        queue = Queue()
        self.connection_manager[file_id] = queue  # register the connection
        try:
            self._transmit_until_fin(datagrams, queue)
        finally:
            self.packet_cache = {}  # reset packet cache
            # deregister the connection (FIN, timeout, closed socket or error)
            self.connection_manager.pop(file_id, None)

    def _build_packets(self, filepath, file_id):
        """Split *filepath* into superframes, FEC-encode them and wrap each shard in a packet.

        Packets are ordered superframe by superframe, shard by shard, so the
        packet for (block_id, fec_seq) is at index block_id * (K + M) + fec_seq.
        """
        data_fragments = utils.split_file_into_superframes(filepath, uni_defaults.SUPERFRAME_SIZE)
        if data_fragments is None:
            raise OSError("could not read %r" % (filepath,))
        expected_blocks = self._blocks_per_file()
        if len(data_fragments) != expected_blocks:
            # Fewer superframes never complete on the receiver; more cannot be
            # addressed by the block_id field and were silently dropped.
            raise ValueError(
                "%r splits into %d superframes of %d bytes; the protocol requires exactly %d"
                % (filepath, len(data_fragments), uni_defaults.SUPERFRAME_SIZE, expected_blocks)
            )
        all_packets = []
        for block_id, superframe in enumerate(data_fragments):
            # encode() returns the list of K + M shards for this superframe
            for fec_seq, shard in enumerate(self.ec_driver.encode(superframe)):
                header = UNIHeader()
                header.from_dict({
                    "packet_type": UNIPacketType.Data.value,
                    "block_id": block_id,
                    "fec_seq": fec_seq,
                    "file_id": file_id,
                })
                packet = UNIPacket()
                packet.from_dict({"header": header, "payload": shard})
                all_packets.append(packet)
        if len(all_packets) != uni_defaults.UNI_TOTAL_PKTS:
            # get_next_sequence_new walks exactly UNI_TOTAL_PKTS indices
            raise ValueError(
                "encoded %d packets, expected UNI_TOTAL_PKTS=%d"
                % (len(all_packets), uni_defaults.UNI_TOTAL_PKTS)
            )
        return all_packets

    def _transmit_until_fin(self, datagrams, queue):
        """Send *datagrams* in interleaved order, repeating rounds until FIN or timeout.

        Once the first full round is out, every iteration first checks the
        timeout and *queue* (fed by ``_sender_packet_loop``), then sends the
        next datagram.
        """
        total = len(datagrams)
        current_seq = 0
        pkts_sent = 0
        deadline = None  # armed once the first full round has been sent
        while True:
            if deadline is not None:
                if time.monotonic() > deadline:
                    print("timeout!")
                    return
                try:
                    fin, _ = queue.get(block=False)
                except Empty:
                    fin = False
                if fin:
                    print("fin, prep next!")
                    return
            try:
                # several send() calls may run concurrently on this socket
                with self.lock:
                    self.socket.sendto(datagrams[current_seq], self.receiver_address)
            except OSError:
                if self._socket_closed():
                    return  # drop() closed the socket; nothing more can be sent
                self._print_exception()
                continue  # retry the same datagram
            pkts_sent += 1
            if pkts_sent % total == 0:
                # round finished: restart at packet 0 so every packet is repeated
                current_seq = 0
            else:
                current_seq = self.get_next_sequence_new(current_seq)
            if deadline is None and pkts_sent >= total:
                deadline = time.monotonic() + self._FIN_TIMEOUT

    # ================================================================= receiver

    def initialize_new_array(self):
        """Return a 4 x 30 list of lists filled with 0 (not called within this class)."""
        raw_frames = [0] * 4
        for i in range(len(raw_frames)):
            raw_frames[i] = [0] * 30
        return raw_frames

    def _receiver_packet_loop(self):
        """Receive thread: feed Data packets to the decoder until the socket is closed."""
        if self.mode == OpMode.SendMode:
            raise RuntimeError("_receiver_packet_loop needs an endpoint bound as receiver")
        self.received_frames = {}
        self.encoded_frames = {}
        blocks_per_file = self._blocks_per_file()
        while True:
            try:
                data, from_addr = self.socket.recvfrom(
                    uni_defaults.UNI_PAYLOAD_SIZE + uni_defaults.UNI_HEADER_SIZE
                )
                self.started = True
                packet = UNIPacket()
                packet.from_raw(data)
                if packet.header.packet_type == UNIPacketType.Data.value:
                    self._handle_data_packet(packet, from_addr, blocks_per_file)
            except Exception:
                if self._socket_closed():
                    return  # stop instead of spinning on errors forever
                self._print_exception()

    def _handle_data_packet(self, packet, from_addr, blocks_per_file):
        """Store one FEC shard, decode its superframe when possible, FIN when the file is done.

        Per-file state:
        ``encoded_frames[file_id]`` maps each block id to the list of shards
        received for it, plus ``"count"`` (shards received per block).
        ``received_frames[file_id]`` holds ``"complete"``, ``"received_blocks"``
        (block ids already decoded), ``"payload"`` (a list extended with each
        decoded block in order) and, per decoded block id, its decoded data.
        """
        file_id = packet.header.file_id
        block_id = packet.header.block_id
        if file_id not in self.encoded_frames:
            encoded = {"count": [0] * blocks_per_file}
            for blk in range(blocks_per_file):
                encoded[blk] = []
            self.encoded_frames[file_id] = encoded
        if file_id not in self.received_frames:
            self.received_frames[file_id] = {"complete": False, "received_blocks": {}, "payload": []}
        received = self.received_frames[file_id]
        if received["complete"]:
            # File already delivered; repeat FIN in case the previous one was lost.
            self.reply(UNIPacketType.Fin.value, from_addr, file_id)
            return
        encoded = self.encoded_frames[file_id]
        encoded[block_id].append(packet.payload)
        encoded["count"][block_id] += 1
        if encoded["count"][block_id] < uni_defaults.EC_K_VALUE or block_id in received["received_blocks"]:
            return
        try:
            decoded = self.ec_driver.decode(encoded[block_id])
        except Exception:
            # Not decodable yet; retried when the next shard of this block arrives.
            return
        received[block_id] = decoded
        received["received_blocks"][block_id] = True
        print("successful decode! fileid: %d | block: %d with %d frags" % (file_id, block_id, encoded["count"][block_id]))
        if len(received["received_blocks"]) == blocks_per_file:
            for blk in range(blocks_per_file):
                received["payload"] += received[blk]
            received["complete"] = True
            print("[RCV] fileid: %d complete!" % file_id)
            # hand the file to recv()
            self.file_received.put((file_id, len(received["payload"])))
            # tell the sender it can move on to the next file
            self.reply(UNIPacketType.Fin.value, from_addr, file_id)

    def recv(self):
        """Block until a file has been fully received and return its contents as bytes."""
        if self.mode == OpMode.SendMode:
            raise RuntimeError("recv() needs an endpoint bound as receiver")
        key, length = self.file_received.get()
        return utils.fold_data(self.received_frames[key]["payload"], length)

    def reply(self, packet_type, addr, file_id):
        """Send a header-only control packet to *addr*; only FIN is supported.

        Other packet types are ignored.
        """
        if self.mode == OpMode.SendMode:
            raise RuntimeError("reply() needs an endpoint bound as receiver")
        if packet_type != UNIPacketType.Fin.value:
            return
        header = UNIHeader()
        header.from_dict({"packet_type": packet_type, "file_id": file_id, "block_id": 0, "fec_seq": 0})
        packet = UNIPacket()
        packet.from_dict({"header": header, "payload": b""})
        self.socket.sendto(packet.raw(), addr)
```
# uni_packet.py
Header definitions for the UNI protocol (subject to change).
```python=
from enum import Enum
# UNI packet header layout (4 bytes, big-endian, bit 0 = most significant):
#
#  0                   1                   2                   3
#  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
# | A | B |    C    |     RSV     |            FileID             |
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
#
# A      (2 bits)  packet type (see UNIPacketType)
# B      (2 bits)  block id: superframe index within the file (0-3)
# C      (5 bits)  FEC sequence: shard index within the superframe (0-31)
# RSV    (7 bits)  reserved, sent as zero
# FileID (16 bits) file this packet belongs to

# Size of the header in bytes.
UNI_HEADER_LENGTH = 4
class UNIPacketType(Enum):
    """Values of the 2-bit packet-type field of the UNI header."""

    Data = 0      # any data that does not mark the end of a superframe/file
    FrameEnd = 1  # last packet of a superframe: receiver may begin decoding it
    Retx = 2      # request to retransmit a missing fragment
    Fin = 3       # sent by the receiver once a file has been fully received
class UNIHeader:
    """Fixed 4-byte UNI packet header.

    Bit layout, most significant bit first:
    packet_type (2 bits) | block_id (2 bits) | fec_seq (5 bits) |
    reserved (7 bits, sent as zero) | file_id (16 bits).

    Populate with ``from_dict`` or ``from_raw``; serialize with ``raw``.
    """

    def from_raw(self, raw):
        """Parse the header fields from the first UNI_HEADER_LENGTH bytes of *raw*.

        Raises:
            ValueError: if *raw* is shorter than the header. Without this check
                a short or empty datagram would decode as a Data header.
        """
        if len(raw) < UNI_HEADER_LENGTH:
            raise ValueError("need %d header bytes, got %d" % (UNI_HEADER_LENGTH, len(raw)))
        header = int.from_bytes(raw[0:UNI_HEADER_LENGTH], "big")
        # shift each field down to bit 0, then mask off the fields above it
        self.packet_type = header >> 30           # top 2 bits: no mask needed
        self.block_id = (header >> 28) & 0b11
        self.fec_seq = (header >> 23) & 0b11111
        self.file_id = header & 0xFFFF            # reserved bits are ignored

    def raw(self):
        """Pack the fields into 4 big-endian bytes.

        Raises:
            ValueError: if a field is negative or does not fit its bit width;
                such values would otherwise spill into neighbouring fields.
        """
        for name, bits in (("packet_type", 2), ("block_id", 2), ("fec_seq", 5), ("file_id", 16)):
            value = getattr(self, name)
            if not 0 <= value < (1 << bits):
                raise ValueError("%s=%r does not fit in %d bits" % (name, value, bits))
        # bitwise OR concatenates the fields; reserved bits stay zero
        header = (self.packet_type << 30) | (self.block_id << 28) | (self.fec_seq << 23) | self.file_id
        return header.to_bytes(4, "big")

    def from_dict(self, dict):
        """Set the fields from a mapping with keys packet_type, block_id, fec_seq and file_id."""
        self.packet_type = dict["packet_type"]
        self.block_id = dict["block_id"]
        self.fec_seq = dict["fec_seq"]
        self.file_id = dict["file_id"]
class UNIPacket:
    """A UNI datagram: a ``UNIHeader`` followed by an opaque payload."""

    def from_raw(self, raw):
        """Split *raw* into its header (first UNI_HEADER_LENGTH bytes) and payload."""
        parsed = UNIHeader()
        parsed.from_raw(raw[:UNI_HEADER_LENGTH])
        self.header = parsed
        self.payload = raw[UNI_HEADER_LENGTH:]

    def raw(self):
        """Serialize the packet: header bytes immediately followed by the payload."""
        return self.header.raw() + self.payload

    def from_dict(self, dict):
        """Set ``header`` and ``payload`` from a mapping with those two keys."""
        self.header = dict["header"]
        self.payload = dict["payload"]
```