###### tags: `ebpf`
# offset type system inference algorithm
```
/*
language supported (so far):
reg ::= unsigned int
stmt ::= reg = *(type *)(reg aop const)
| *(type *)(reg aop const) = reg
| reg = reg
| reg = reg aop val
| assume reg rop reg;
type ::= u64 | u32 | u16 | u8
aop ::= + | -
rop ::= > | <=
val ::= reg | const
*/
```
## Constants
```
// Per-region start values; all four are 0.
int STACK_BEGIN = 0;
int CTX_BEGIN = 0;
int PACKET_BEGIN = 0;
int SHARED_BEGIN = 0;
// Initial distance recorded for the ctx slot at PACKET_END_LOC (see Initializations).
// NOTE(review): the meaning of the magnitude 4100 is not stated in this note - confirm.
int PACKET_END = -4100;
// Initial distance recorded for the ctx slot at PACKET_META_LOC (see Initializations).
int PACKET_META = -1;
// Presumably an upper bound on the packet size; not referenced elsewhere in this note - TODO confirm.
int MAX_PACKET_SIZE = 0xffff;
```
## Given
```
// Inputs supplied from outside the algorithm: the locations, inside the context,
// of the packet_begin, packet_end and packet_meta pointers, plus the context size.
// The *_LOC values are used as keys of the initial ctx distance map (see Initializations).
extern int PACKET_BEGIN_LOC;
extern int PACKET_END_LOC;
extern int PACKET_META_LOC;
// Copied into `m_ctx_state.m_size` of the initial state.
extern int CTX_SIZE;
```
## Definitions
```
// Memory regions a pointer can belong to.
using region_t = enum {T_PACKET, T_CTX, T_SHARED, T_STACK};
// Pointer type that records only the region.
using ptr_no_off_t = struct {
    region_t m_r;
};
// Pointer type that records the region together with an integer offset.
using ptr_with_offset_t = struct {
    region_t m_r;
    int m_offset;
};
using interval_t = std::pair<int, int>;
using weight_t = interval_t; // should be interval_t | symbol_t
using slack_var_t = int;
using rop_t = enum {R_GT, R_GE, R_LT, R_LE};
// A distance:
// if m_dist is +ve, represents `begin+m_dist+s;` where s is some slack variable represented by a non-negative number m_slack,
// if m_dist is -ve, represents `end+m_dist+1` where m_slack is -1
using dist_t = struct {
    slack_var_t m_slack;
    weight_t m_dist;
};
// represents constraint `p[0] = p[1];`, e.g., `begin+8+s = end`
using dist_equality_t = std::pair<dist_t, dist_t>;
// represents `s m_rel m_value;` constraint, where e.g., `s >= [0,0]`; it is the constraint
// on the slack variable that accompanies a dist_t equality
using inequality_t = struct {
    slack_var_t m_slack;
    rop_t m_rel;
    weight_t m_value;
};
// An equality between two distances paired with the inequality on its slack variable.
using packet_constraint_t = std::pair<dist_equality_t, inequality_t>;
// actual implementation
// using live_registers_t = std::array<reg_with_loc_t, 11>;
// using global_offset_env_t = std::unordered_map<reg_with_loc_t, dist_t>;
// using registers_state_t = struct {
// live_registers_t m_cur_def;
// std::shared_ptr<global_offset_env_t> m_offset_env;
// }
// represents `r{n} = dist : dist_t;`, where {n} \belongs [0,10], e.g., `r1 = begin+8`
using registers_state_t = std::array<dist_t, 11>;
// represents `sp[n] = dist;`, where n \belongs [0,511], e.g., `sp[508] = begin+16`
using stack_state_t = std::unordered_map<int, dist_t>;
// The two packet constraints tracked in a state: meta vs. begin, and begin vs. end.
using extra_constraints_t = struct {
    packet_constraint_t m_meta_and_begin;
    packet_constraint_t m_begin_and_end;
};
// represents `cp[n] = dist;`
using ctx_dists_t = std::unordered_map<unsigned int, dist_t>;
// Distances stored in the context together with the context size.
using ctx_t = struct {
    ctx_dists_t m_dists;
    int m_size;
};
```
## State
```
using state_t = struct {
register_state_t m_reg_state;
stack_state_t m_stack_state;
extra_constraints_t m_extra_constraints;
ctx_t m_ctx_state;
slack_var_t m_slack = 0;
};
```
## Initializations
```
// State recorded for each CFG node once its instructions have been processed.
std::unordered_map<cfg_node, state_t> states;
// Seed node 0 (presumably the entry node - TODO confirm) with the distances of the
// three packet pointers stored in the context. `m_dists` is a ctx_dists_t, and the
// distances are the named constants from the Constants section.
states[0].m_ctx_state.m_dists = ctx_dists_t{
    PACKET_BEGIN_LOC: dist_t{slack: null, dist: PACKET_BEGIN},
    PACKET_META_LOC: dist_t{slack: null, dist: PACKET_META},
    PACKET_END_LOC: dist_t{slack: null, dist: PACKET_END}
};
states[0].m_ctx_state.m_size = CTX_SIZE;
```
## Algorithm
```
// Returns true only when the two distances are known to denote the same location.
// Every case that cannot be decided returns false, which makes the joins below
// drop the entry instead of keeping a possibly wrong one.
def equal_dists(dist_t dist1, dist_t dist2):
    if dist1.slack is null and dist2.slack is null:
        // case for registers and stack slots
        if (dist1.dist >= 0 and dist2.dist >= 0) or (dist1.dist < 0 and dist2.dist < 0):
            // both forward (relative to begin) or both backward (relative to end)
            return dist1.dist == dist2.dist;
        // one register/stack slot is defined as forward distance and the other as backward;
        // not a common case but might happen. Comparing them would need the begin/end
        // relation, so report "not equal".
        return false;
    // both with slack: we do not get here for registers and stack slots;
    // exactly one with slack: very likely just "not equal", but might need other analysis
    return false;
// Join of two stack states: keep a slot only when both sides map it to equal distances.
def operator||(stack_state_t stack_dists1, stack_state_t stack_dists2):
    stack_state_t stack_dists_joined; // bottom
    for (stack_slot, stack_slot_dist) in stack_dists1:
        if stack_slot in stack_dists2 and equal_dists(stack_slot_dist, stack_dists2[stack_slot]):
            stack_dists_joined[stack_slot] = stack_slot_dist;
    return stack_dists_joined;
// Join of two register states: keep a register only when both sides map it to equal distances.
def operator||(registers_state_t reg_state1, registers_state_t reg_state2):
    registers_state_t reg_dists_joined; // bottom
    for (reg, reg_dist) in reg_state1:
        if reg in reg_state2 and equal_dists(reg_dist, reg_state2[reg]):
            reg_dists_joined[reg] = reg_dist;
    return reg_dists_joined;
// Join of two constraint systems of the form `begin+n+s = end+m; s >= p`.
// Both systems are normalized to `begin+a+s = end; s >= 0`; the result keeps the
// smaller constant, the only lower bound on end-begin that holds on both paths.
// Returns bottom when either equality carries no slack variable.
// NOTE(review): the arguments are indexed as [0] (equality) and [1] (inequality), i.e. like
// a packet_constraint_t, while extra_constraints_t is declared with two named
// packet_constraint_t members - confirm which constraint is meant.
def operator||(extra_constraints_t extra_cnstrnts1, extra_constraints_t extra_cnstrnts2):
    // represents begin+n+s = end+m;
    dist_equality_t eq1 = extra_cnstrnts1[0];
    dist_equality_t eq2 = extra_cnstrnts2[0];
    // s >= p
    inequality_t ineq1 = extra_cnstrnts1[1];
    inequality_t ineq2 = extra_cnstrnts2[1];
    extra_constraints_t cnstrnts_joined; // bottom
    // normalize end+m by moving m to lhs, begin+n-m+s = end; for both systems
    // (a backward dist of -1 denotes `end` itself)
    eq1[0].dist = eq1[0].dist-eq1[1].dist-1;
    eq1[1].dist = -1;
    eq2[0].dist = eq2[0].dist-eq2[1].dist-1;
    eq2[1].dist = -1;
    // normalize s >= p by adding p to lhs of equality, begin+n-m+p+s = end, s >= 0;
    eq1[0].dist = eq1[0].dist+ineq1.value;
    ineq1.value = 0;
    eq2[0].dist = eq2[0].dist+ineq2.value;
    ineq2.value = 0;
    // given two normalized systems:
    // 1. begin+a+s = end, s >= 0
    // 2. begin+b+t = end, t >= 0
    // should be true that both equalities contain slack
    if eq1[0].slack is not null and eq2[0].slack is not null:
        slack_var_t s;
        if eq1[0].slack == eq2[0].slack: // s == t
            s = eq1[0].slack;
        else:
            s = global_slack_counter++; // new slack variable u (slack_var_t is an int id)
        // begin+min(a,b)+u = end
        dist_t f = dist_t{slack: s, dist: min(eq1[0].dist, eq2[0].dist)};
        dist_t b = eq1[1];
        dist_equality_t eq = std::make_pair(f, b);
        // u >= 0
        inequality_t ineq = inequality_t{slack: s, rel: ineq1.rel, value: 0};
        cnstrnts_joined = std::make_pair(eq, ineq);
    return cnstrnts_joined;
// Join of two full states, component by component.
def operator||(state_t st1, state_t st2):
    state_t st_joined = _|_; // bottom
    st_joined.m_reg_state = st1.m_reg_state || st2.m_reg_state;
    st_joined.m_stack_state = st1.m_stack_state || st2.m_stack_state;
    st_joined.m_extra_constraints = st1.m_extra_constraints || st2.m_extra_constraints;
    // No transfer function in this note writes to the context ("we do not store any
    // pointers into ctx"), so carry one side's copy over instead of leaving it bottom.
    // NOTE(review): assumes both sides hold the same ctx state - confirm.
    st_joined.m_ctx_state = st1.m_ctx_state;
    st_joined.m_slack = std::max(st1.m_slack, st2.m_slack);
    return st_joined;
// Joins the recorded states of all predecessors of bb.
// Returns bottom when bb has no predecessors.
def join(cfg_node bb):
    state_t state = _|_; // state is bottom
    bool first = true;
    // iterate over the predecessors directly; an index range ending at
    // all_preds.length() would run past the end if read as inclusive
    for pred in cfg.get_all_predecessors(bb):
        if first:
            state = states[pred];
            first = false;
        else:
            state = state || states[pred];
    return state;
```
```
// Slack ids must stay distinct across basic blocks because the join compares them for
// identity, so the counter is created once, before the traversal.
unsigned int global_slack_counter = 0;
for bb in cfg.topological_sort():
    if cfg.num_of_predecessors(bb) == 1:
        state = states[cfg.get_predecessor(bb)];
    else if cfg.num_of_predecessors(bb) > 1:
        state = join(bb);
    else:
        // no predecessor: start from whatever was pre-initialized for this node (states[0])
        state = states[bb];
    for instr in bb:
        match(instr) with:
```
```
va = *(T *)(vb arithm_oprtr vc) ->
    // Load: va gets the distance stored at the accessed ctx location or stack slot,
    // if one is known; in every other case the distance of va is forgotten.
    ptr_t vb_type = get_type_of_register(vb); // interact with type domain
    bool loaded = false;
    if vb_type is ptr_with_offset_t:
        ptr_with_offset_t vb_type_with_offset = (ptr_with_offset_t)vb_type;
        // apply the instruction's own operator (aop is + or -)
        int load_at = vb_type_with_offset.offset arithm_oprtr vc;
        if vb_type_with_offset.region is T_CTX:
            // NOTE(review): `ctx_dists` is not declared in this note; presumably the
            // distances set up in Initializations (m_ctx_state.m_dists) - confirm.
            if load_at in ctx_dists:
                state.m_reg_state[va] = ctx_dists[load_at];
                loaded = true;
        else if vb_type_with_offset.region is T_STACK:
            if load_at in state.m_stack_state:
                state.m_reg_state[va] = state.m_stack_state[load_at];
                loaded = true;
    // unknown location, or base register is a packet/shared pointer or has no offset:
    // we do not store any pointers there, so forget the offset of va
    if not loaded:
        state.m_reg_state -= va;
```
```
*(T *)(va arithm_oprtr vb) = vc ->
    // Store: only stores into the stack are tracked.
    ptr_t va_type = get_type_of_register(va);
    ptr_t? vc_type = get_type_of_register(vc); // ptr_t or null
    if va_type is not ptr_with_offset_t:
        return; // shared or packet pointer; we do not store any pointers into those regions
    ptr_with_offset_t va_type_with_offset = (ptr_with_offset_t)va_type;
    if va_type_with_offset.region is not T_STACK:
        return; // ctx pointer; we do not store any pointers into ctx
    // apply the instruction's own operator (aop is + or -)
    int store_at = va_type_with_offset.offset arithm_oprtr vb;
    if vc_type is null or vc not in state.m_reg_state:
        // a number (or a pointer with unknown distance) overwrites the slot: forget it.
        // Returning early here would leave the slot's previous distance in place.
        state.m_stack_state -= store_at;
    else:
        state.m_stack_state[store_at] = state.m_reg_state[vc];
```
```
va = vb ->
    // Copy: va takes vb's distance when vb is a pointer with a known distance;
    // otherwise the distance of va is forgotten.
    ptr_t? vb_type = get_type_of_register(vb);
    if vb_type is not null and vb in state.m_reg_state:
        state.m_reg_state[va] = state.m_reg_state[vb];
    else:
        state.m_reg_state -= va;
```
```
va = vb arithm_oprtr vc ->
    // Arithmetic: va gets a distance only when vb is a packet pointer with a known
    // constant distance and vc is a constant. In every other case the distance of va
    // is forgotten, so that no stale value survives the assignment.
    ptr_t? vb_type = get_type_of_register(vb); // ptr_t or null
    if vb_type is not null and vb_type.region is T_PACKET and vb in state.m_reg_state and vc is const:
        dist_t vb_dist = state.m_reg_state[vb];
        if vb_dist is constant_t:
            // apply the instruction's own operator (aop is + or -)
            // NOTE(review): a subtraction can turn a forward (+ve) distance negative, which
            // dist_t interprets as relative to `end` - confirm how this should be handled.
            weight_t d = vb_dist.dist arithm_oprtr vc;
            state.m_reg_state[va] = dist_t{slack: vb_dist.slack, dist: d};
            return;
    // vb is a number, a stack/ctx/shared pointer, a packet pointer with unknown or
    // non-numeric offset, or vc is a register
    state.m_reg_state -= va;
    return;
```
```
assume va <= vb ->
    // Records `va <= vb` as the equality `va_dist + s = vb_dist` with a fresh slack s >= 0.
    // NOTE(review): m_extra_constraints is indexed as [0] (equality) and [1] (inequality),
    // while extra_constraints_t is declared with two named members - confirm the target.
    if va not in state.m_reg_state or vb not in state.m_reg_state:
        return;
    dist_t va_dist = state.m_reg_state[va];
    dist_t vb_dist = state.m_reg_state[vb];
    slack_var_t s = global_slack_counter++; // fresh slack id (slack_var_t is an int)
    dist_t f = dist_t{slack: s, dist: va_dist.dist};
    dist_t b = dist_t{slack: null, dist: vb_dist.dist};
    state.m_extra_constraints[0] = dist_equality_t{std::make_pair(f, b)};
    state.m_extra_constraints[1] = inequality_t{slack: s, rel: R_GE, value: 0};
```
```
// Record the state computed for bb; later blocks read it through states[predecessor].
states[bb] = state;
```
# An older version of the Offset inference
```
/*
language supported (so far):
reg ::= unsigned int
stmt ::= reg = *(type *)(reg aop const)
| *(type *)(reg aop const) = reg
| reg = reg
| reg = reg aop val
| assume reg rop reg;
type ::= u64 | u32 | u16 | u8
aop ::= + | -
rop ::= > | <=
val ::= reg | const
*/
```
```
// Memory regions a pointer can belong to.
using region_t = enum {T_PACKET, T_CTX, T_SHARED, T_STACK};
// NOTE(review): some platforms already declare `register_t` in <sys/types.h>; this alias
// would clash if that header is included - confirm.
using register_t = unsigned int; // 0-10 values, representing r0-r10
// Pointer type that records only the region.
struct ptr_t {
    region_t region;
};
// Pointer type that additionally records an integer offset.
struct ptr_with_offset_t : ptr_t {
    int offset;
};
```
```
// Relation carried by an edge of the offset graph.
using relation_t = enum {R_EQ, R_GT, R_LE};
// Kinds of locations that can hold a pointer.
using location_t = enum {REG, STACK, CTX};
using constant = int; // define a domain for constants
using weight_t = constant | symbol; // should be constants + symbols
using vert_id = unsigned int;
// A node of the offset graph, identified by its id.
struct node_t {
    vert_id m_node_id;
};
// A directed edge `m_from -> m_to` labelled with a distance and a relation.
struct edge_t {
    node_t m_from;
    node_t m_to;
    weight_t m_dist;
    relation_t m_relation;
};
using register_offset_t = std::array<node_t, 11>;
using stack_offsets_t = std::unordered_map<int, node_t>;
using ctx_offsets_t = std::unordered_map<int, node_t>;
using out_edges_t = std::unordered_map<node_t, std::vector<edge_t>>;
using stores_pointer_t = std::pair<location_t, unsigned int>; // either REG+register_number, STACK+stack_offset, CTX+ctx_offset
using node_pointed_by_t = std::unordered_map<node_t, std::vector<stores_pointer_t>>; // locations the node is being pointed by; could be multiple
// The graph itself plus the reverse mapping from nodes to the locations that point at them.
struct offset_graph_t {
    out_edges_t m_out_edges; // outgoing edges from each node
    node_pointed_by_t m_to_node_mappings;
};
```
```
// Abstract state of the older, graph-based version.
struct state_t {
    register_offset_t m_curr_reg_offsets; // mapping of registers to nodes
    stack_offsets_t m_curr_stack_offsets; // mapping of stack offsets to nodes
    offset_graph_t m_offset_graph;
};
map<cfg_node, state_t> states;
// node_t(0) represents packet_start, node_t(1) represents packet_end
// NOTE(review): 76 and 80 are presumably the ctx offsets of the packet start/end
// pointers (they match `data`/`data_end` in `struct __sk_buff`) - confirm.
ctx_offsets_t ctx_offsets = map{76: node_t(0), 80: node_t(1)};
// Initial state: both special nodes are pointed at by their ctx locations.
// NOTE(review): state0 is not stored into `states` anywhere in this note - confirm how
// the entry block obtains it.
state_t state0;
state0.m_offset_graph.m_to_node_mappings[node_t(0)].insert(std::pair(CTX, 76));
state0.m_offset_graph.m_to_node_mappings[node_t(1)].insert(std::pair(CTX, 80));
```
```
// given two nodes and their distances from their respective start nodes in two graphs, check if distances and mappings to nodes match.
// Walks forward from curr_node in st, accumulating edge distances, and returns true when a
// node at exactly to_search_dist is pointed at by location p. Non-constant distances are
// not compared (returns false).
def look_for_node_with_same_dist(state_t st, node_t curr_node, weight_t curr_dist, node_t to_search_node, weight_t to_search_dist, stores_pointer_t p):
    if curr_dist is not constant or to_search_dist is not constant:
        return false;
    if curr_dist == to_search_dist:
        // same distance: it is a match only if the same location points here
        return p in st.m_offset_graph.m_to_node_mappings[curr_node];
    if curr_dist > to_search_dist:
        return false;
    // still short of the target distance: keep following outgoing edges
    for out in st.m_offset_graph.m_out_edges[curr_node]:
        if look_for_node_with_same_dist(st, out.m_to, curr_dist+out.m_dist, to_search_node, to_search_dist, p):
            return true;
    return false;
// Given a node of one graph and its distance from the start node, report whether the
// other graph (st) has a node at the same distance that is pointed at by p.
// The search starts at each outgoing edge of node_t(0) and stops at the first match.
def search_in_second_graph(state_t st, node_t to_search_node, weight_t to_search_dist, stores_pointer_t p):
    root = node_t(0);
    bool found = false;
    for first_edge in st.m_offset_graph.m_out_edges[root]:
        if look_for_node_with_same_dist(st, first_edge.m_to, first_edge.m_dist, to_search_node, to_search_dist, p):
            found = true;
            break;
    return found;
// Join of two graph-based states. The result keeps st1's equality edges and, for every
// node, only those pointing locations that st2 also has at the same distance.
def operator||(state_t st1, state_t st2):
    state_t state = _|_;
    node_t data_node = node_t(0);
    node_t data_end_node = node_t(1);
    // for data_node and data_end_node, we only need to check if node mappings are consistent in both states, no traversals needed
    for node in std::vector{data_node, data_end_node}:
        for pointed_by in st1.m_offset_graph.m_to_node_mappings[node]:
            if pointed_by in st2.m_offset_graph.m_to_node_mappings[node]:
                state.m_offset_graph.m_to_node_mappings[node].insert(pointed_by);
                if pointed_by[0] == REG:
                    state.m_curr_reg_offsets[pointed_by[1]] = node;
                else if pointed_by[0] == STACK:
                    state.m_curr_stack_offsets[pointed_by[1]] = node;
                // CTX locations are kept in ctx_offsets, not in the per-state maps;
                // recording them as stack slots would invent a stack entry
    // worklist of (node, distance from data_node); indexed loop because it grows while being walked
    std::vector<std::pair<node_t, weight_t>> worklist = std::vector{std::pair(data_node, 0)};
    // join the two graphs
    for (i = 0; i < worklist.size(); i++):
        (node, dist) = worklist[i];
        for out_edge in st1.m_offset_graph.m_out_edges[node]:
            weight_t dist_from_start = out_edge.m_dist+dist;
            worklist.push_back(std::pair(out_edge.m_to, dist_from_start));
            if out_edge.m_relation == R_EQ: // right now handling only equalities
                // add the edge to the node's edge list (do not replace the list)
                state.m_offset_graph.m_out_edges[node].insert(out_edge);
                for pointed_by in st1.m_offset_graph.m_to_node_mappings[out_edge.m_to]:
                    if search_in_second_graph(st2, out_edge.m_to, dist_from_start, pointed_by):
                        // the location points at the edge's target node, not at its source
                        state.m_offset_graph.m_to_node_mappings[out_edge.m_to].insert(pointed_by);
                        if pointed_by[0] == REG:
                            state.m_curr_reg_offsets[pointed_by[1]] = out_edge.m_to;
                        else if pointed_by[0] == STACK:
                            state.m_curr_stack_offsets[pointed_by[1]] = out_edge.m_to;
            // else: handle inequalities (not done yet)
    return state;
```
```
// Joins the recorded states of all predecessors of bb.
// Returns bottom when bb has no predecessors.
def join(cfg_node bb):
    state_t state = _|_; // state is bottom
    bool first = true;
    // iterate over the predecessors directly; an index range ending at
    // all_preds.length() would run past the end if read as inclusive
    for pred in cfg.get_all_predecessors(bb):
        if first:
            state = states[pred];
            first = false;
        else:
            state = state || states[pred];
    return state;
```
```
// Makes (REG, reg) point at node n, removing the register's previous binding first.
def _bind_register(register_t reg, node_t n):
    if reg in state.m_curr_reg_offsets:
        state.m_offset_graph.m_to_node_mappings[state.m_curr_reg_offsets[reg]].erase(std::pair(REG,reg));
    state.m_curr_reg_offsets[reg] = n;
    state.m_offset_graph.m_to_node_mappings[n].insert(std::pair(REG,reg));

// Binds dest_reg to the node lying `dist` after curr_node, creating that node if needed.
// With a constant dist, the first constant equality edge out of curr_node is either
// split (dist falls inside it), reused (dist equals it), or followed (dist goes past it).
// When no such edge exists, or dist is not a constant, a new node is appended to curr_node.
// NOTE(review): negative distances are not handled here - confirm callers pass dist >= 0.
def add_node(node_t curr_node, register_t dest_reg, weight_t dist):
    if dist is constant:
        // iterate over a copy: the edge list of curr_node is modified below
        for edge in copy(state.m_offset_graph.m_out_edges[curr_node]):
            if edge.m_dist is constant and edge.m_relation is R_EQ:
                if dist < edge.m_dist:
                    // split `curr --edge--> to` into `curr --dist--> dest --rest--> to`
                    node_t dest_node = node_t(global_node_counter++);
                    edge_t e = {m_from: curr_node, m_to: dest_node, m_dist: dist, m_relation: R_EQ};
                    edge_t e1 = {m_from: dest_node, m_to: edge.m_to, m_dist: edge.m_dist-dist, m_relation: R_EQ};
                    state.m_offset_graph.m_out_edges[curr_node].insert(e);
                    state.m_offset_graph.m_out_edges[dest_node].insert(e1);
                    state.m_offset_graph.m_out_edges[curr_node].erase(edge);
                    _bind_register(dest_reg, dest_node);
                else if dist == edge.m_dist:
                    _bind_register(dest_reg, edge.m_to);
                else:
                    add_node(edge.m_to, dest_reg, dist-edge.m_dist);
                return;
    // no constant equality edge to reuse, or a non-constant distance: append a new node
    node_t dest_node = node_t(global_node_counter++);
    edge_t e = {m_from: curr_node, m_to: dest_node, m_dist: dist, m_relation: R_EQ};
    state.m_offset_graph.m_out_edges[curr_node].insert(e);
    _bind_register(dest_reg, dest_node);
```
```
// Node ids 0 and 1 are reserved (packet_start, packet_end). The counter is created once:
// a block inherits its predecessors' nodes, so restarting at 2 would reuse their ids.
vert_id global_node_counter = 2;
for bb in cfg.topological_sort():
    if cfg.num_of_predecessors(bb) == 1:
        state = states[cfg.get_predecessor(bb)];
    else if cfg.num_of_predecessors(bb) > 1:
        state = join(bb);
    else:
        state = state0; // no predecessor: start from the initial state
    for instr in bb:
        match(instr) with:
```
```
va = *(T *)(vb arithm_oprtr vc) ->
    // Load: va is overwritten in every case, so its previous binding is dropped first;
    // it is re-bound only when the accessed ctx location or stack slot maps to a node.
    ptr_t vb_type = get_type_of_register(vb); // interact with type domain
    if va in state.m_curr_reg_offsets:
        state.m_offset_graph.m_to_node_mappings[state.m_curr_reg_offsets[va]].erase(std::pair(REG,va));
        state.m_curr_reg_offsets -= va; // forget the offset of va
    if vb_type is not ptr_with_offset_t:
        return;
    ptr_with_offset_t vb_type_with_offset = (ptr_with_offset_t)vb_type;
    // apply the instruction's own operator (aop is + or -)
    int load_at = vb_type_with_offset.offset arithm_oprtr vc;
    node_t? n = null;
    if vb_type_with_offset.region is T_CTX:
        if load_at in ctx_offsets:
            n = ctx_offsets[load_at];
    else if vb_type_with_offset.region is T_STACK:
        if load_at in state.m_curr_stack_offsets:
            n = state.m_curr_stack_offsets[load_at];
    // else: base register is a packet or shared pointer; we do not store any pointers there
    if n is not null:
        state.m_curr_reg_offsets[va] = n;
        state.m_offset_graph.m_to_node_mappings[n].insert(std::pair(REG,va));
```
```
*(T *)(va arithm_oprtr vb) = vc ->
    // Store: only stores into the stack are tracked.
    ptr_t va_type = get_type_of_register(va);
    ptr_t? vc_type = get_type_of_register(vc); // ptr_t or null
    if va_type is not ptr_with_offset_t:
        return; // shared or packet pointer; we do not store any pointers into those regions
    ptr_with_offset_t va_type_with_offset = (ptr_with_offset_t)va_type;
    if va_type_with_offset.region is not T_STACK:
        return; // ctx pointer; we do not store any pointers into ctx
    // apply the instruction's own operator (aop is + or -)
    int store_at = va_type_with_offset.offset arithm_oprtr vb;
    // the slot is overwritten in every case: drop its previous binding first.
    // Returning early when vc is a number would leave that binding in place.
    if store_at in state.m_curr_stack_offsets:
        state.m_offset_graph.m_to_node_mappings[state.m_curr_stack_offsets[store_at]].erase(std::pair(STACK,store_at));
        state.m_curr_stack_offsets -= store_at;
    if vc_type is not null and vc in state.m_curr_reg_offsets:
        node_t n = state.m_curr_reg_offsets[vc];
        state.m_curr_stack_offsets[store_at] = n;
        state.m_offset_graph.m_to_node_mappings[n].insert(std::pair(STACK,store_at));
```
```
va = vb ->
    // Copy: va ends up pointing at vb's node, or at nothing when vb has none.
    ptr_t? vb_type = get_type_of_register(vb);
    // read vb's node before touching va, so that `va = va` keeps its binding
    node_t? n = null;
    if vb_type is not null and vb in state.m_curr_reg_offsets:
        n = state.m_curr_reg_offsets[vb];
    // va is overwritten in every case: drop its previous binding
    if va in state.m_curr_reg_offsets:
        state.m_offset_graph.m_to_node_mappings[state.m_curr_reg_offsets[va]].erase(std::pair(REG,va));
        state.m_curr_reg_offsets -= va;
    if n is not null:
        state.m_curr_reg_offsets[va] = n;
        state.m_offset_graph.m_to_node_mappings[n].insert(std::pair(REG,va));
```
```
va = vb arithm_oprtr vc ->
    // Arithmetic: va gets a node only for `packet pointer + vc`; otherwise it is forgotten.
    // remember, vc might be a constant or a known/unknown register
    ptr_t? vb_type = get_type_of_register(vb); // ptr_t or null
    // read vb's node before touching va, so that `va = va + c` still works
    node_t? vb_node = null;
    if vb_type is not null and vb_type.region is T_PACKET and vb in state.m_curr_reg_offsets:
        vb_node = state.m_curr_reg_offsets[vb];
    // va is overwritten in every case: drop its previous binding
    if va in state.m_curr_reg_offsets:
        state.m_offset_graph.m_to_node_mappings[state.m_curr_reg_offsets[va]].erase(std::pair(REG,va));
        state.m_curr_reg_offsets -= va;
    // vb is a number or a stack/ctx/shared pointer, or the operator is `-`: add_node only
    // walks forward, so passing vc for a subtraction would record +vc
    if vb_node is null or arithm_oprtr is not +:
        return;
    add_node(vb_node, va, vc); // add a new node for va
    return;
```
```
assume va <= vb ->
    // Adds an R_LE edge of distance 0 from va's node to vb's node.
    // Nothing is recorded unless both registers currently map to a node.
    bool both_mapped = va in state.m_curr_reg_offsets and vb in state.m_curr_reg_offsets;
    if not both_mapped:
        return;
    node_t source = state.m_curr_reg_offsets[va];
    node_t target = state.m_curr_reg_offsets[vb];
    edge_t le_edge = {m_from: source, m_to: target, m_dist: 0, m_relation: R_LE};
    state.m_offset_graph.m_out_edges[source].insert(le_edge);
```
```
// Adds an R_GT edge of distance 0 from va's node to vb's node.
assume va > vb ->
// nothing is recorded unless both registers currently map to a node
if va not in state.m_curr_reg_offsets or vb not in state.m_curr_reg_offsets:
return;
node_t va_node = state.m_curr_reg_offsets[va];
node_t vb_node = state.m_curr_reg_offsets[vb];
state.m_offset_graph.m_out_edges[va_node].insert({m_from: va_node, m_to: vb_node, m_dist: 0, m_relation: R_GT});
// NOTE(review): the next line presumably belongs after the instruction loop (record the
// state computed for bb), not inside this `assume` case - confirm.
states[bb] = state;
```