# Memory Forensics on Torch
The source code and everything else required to run the project can be found in my [GitHub repository](https://github.com/meow87wang/ml-forensics).
## 1. Build the CNN model
<!-- https://www.kaggle.com/code/shadabhussain/cifar-10-cnn-using-pytorch#Defining-the-Model-(Convolutional-Neural-Network) -->
Using [this](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html) as a reference, I constructed a CNN with the following architecture:
```python
self.network = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),  # output: 64 x 16 x 16
    nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),  # output: 128 x 8 x 8
    nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),  # output: 256 x 4 x 4
    nn.Flatten(),
    nn.Linear(256 * 4 * 4, 1024),
    nn.ReLU(),
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Linear(512, 10))
```
The training result is:
```
Load and normalize CIFAR10
shape: torch.Size([3, 32, 32])
Define a Convolutional Neural Network
####################
torch version: 2.9.0+cu128
random_seed=904187617
num_epochs=10
batch_size=128
opt_func=<class 'torch.optim.adam.Adam'>
lr=0.001
Epoch [0], train_loss: 1.6993, val_loss: 1.3804, val_acc: 0.4908
Epoch [1], train_loss: 1.1624, val_loss: 1.0466, val_acc: 0.6123
Epoch [2], train_loss: 0.9035, val_loss: 0.8601, val_acc: 0.6928
Epoch [3], train_loss: 0.7205, val_loss: 0.7523, val_acc: 0.7393
Epoch [4], train_loss: 0.5847, val_loss: 0.7473, val_acc: 0.7512
Epoch [5], train_loss: 0.4653, val_loss: 0.7476, val_acc: 0.7600
Epoch [6], train_loss: 0.3711, val_loss: 0.7276, val_acc: 0.7691
Epoch [7], train_loss: 0.2817, val_loss: 0.8300, val_acc: 0.7607
Epoch [8], train_loss: 0.2200, val_loss: 0.8787, val_acc: 0.7541
Epoch [9], train_loss: 0.1609, val_loss: 0.9389, val_acc: 0.7711
```
## 2. Export the model
The original [C++ export tutorial](https://docs.pytorch.org/tutorials/advanced/cpp_export.html) was already deprecated before I started the project, but this [page](https://docs.pytorch.org/cppdocs/#torchscript) shows that it refers to TorchScript. So I used the following code to export my model:
```python
m = torch.jit.script(model)
torch.jit.save(m, './cifar10-cnn-jit.pt')
```
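Before moving to C++, it is worth verifying that the scripted copy behaves identically to the eager-mode model. A minimal sanity check with a small stand-in network (not the project's actual CNN):

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in model: script it, save it, reload it, and check
# that the reloaded module reproduces the original outputs exactly.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 32 * 32, 10),
)
model.eval()

scripted = torch.jit.script(model)
path = os.path.join(tempfile.mkdtemp(), 'tiny-jit.pt')
torch.jit.save(scripted, path)
loaded = torch.jit.load(path)

x = torch.rand(1, 3, 32, 32)
with torch.no_grad():
    ok = torch.allclose(model(x), loaded(x), atol=1e-6)
print(ok)  # expect True: the round trip preserves behavior
```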
## 3. C++ program
I use `torch::jit::load` to load the model. To verify the memory analysis results later, I use the [`Module::dump`](https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/api/module.cpp#L616) function as a ground-truth reference.
```c++
#include <ATen/core/jit_type_base.h>
#include <torch/script.h>
#include <iostream>

int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: " << argv[0] << " <path-to-exported-module>\n";
    return -1;
  }
  torch::jit::script::Module module;
  try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    module = torch::jit::load(argv[1]);
  } catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return -1;
  }
  std::cout << "Model loaded successfully.\n";

  // Create a vector of inputs.
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::rand({64, 3, 32, 32}));

  // Execute the model and turn its output into a tensor.
  at::Tensor output = module.forward(inputs).toTensor();
  std::cout << "Output: " << output << "\n";
  return 0;
}
```
## 4. Reverse Engineering the Model Object via GDB
I utilized GDB (with the GEF extension) to dive into the `torch::jit::script::Module` object.
### `torch::jit::script::Module`
`torch::jit::script::Module` inherits from `torch::jit::Object`; the [source code](https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/api/module.h#L87) reveals three additional data members: `mem_to_delete_`, `traced_inputs_`, and `register_mutex_`. Through manual memory manipulation (zeroing these members and calling `dump`), I confirmed that the essential model information is stored within the `Object` base class rather than in these auxiliary members.
The memory layout of `module` is:
```
0x00-0x08: <torch::jit::Object> (_ivalue_.target_)
0x08-0x18: mem_to_delete_
0x18-0x20: traced_inputs_
0x20-0x30: register_mutex_
```
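Under the assumption of an x86-64 process (little-endian, 8-byte pointers, a `std::shared_ptr` being two pointers), this layout can be decoded from a raw byte dump with a few `struct` calls. A sketch, not the plugin's actual code:

```python
import struct

# Sketch of decoding the 0x30-byte Module struct from raw bytes.
# Offsets follow the layout above.
def parse_module(raw: bytes) -> dict:
    assert len(raw) >= 0x30
    (ivalue_target,) = struct.unpack_from('<Q', raw, 0x00)   # Object pointer
    mem_to_delete = struct.unpack_from('<2Q', raw, 0x08)     # shared_ptr<char>
    (traced_inputs,) = struct.unpack_from('<Q', raw, 0x18)
    register_mutex = struct.unpack_from('<2Q', raw, 0x20)    # shared_ptr<std::mutex>
    return {
        'ivalue_target': ivalue_target,
        'mem_to_delete': mem_to_delete,
        'traced_inputs': traced_inputs,
        'register_mutex': register_mutex,
    }

# Synthetic buffer mimicking the GDB session below: only the Object pointer set.
raw = struct.pack('<6Q', 0x555558AB3B70, 0, 0, 0, 0, 0)
fields = parse_module(raw)
print(hex(fields['ivalue_target']))  # 0x555558ab3b70
```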
```shell
gef➤ p &module
$7 = (torch::jit::script::Module *) 0x7fffffffda70
gef➤ call (void)memset(0x7fffffffda78, 0, 0x28)
gef➤ p module
$8 = {
<torch::jit::Object> = {
_ivalue_ = {
target_ = 0x555558ab3b70
}
},
members of torch::jit::Module:
mem_to_delete_ = std::shared_ptr<char> (empty) = {
get() = 0x0
},
traced_inputs_ = {
impl_ = {
target_ = 0x0
}
},
register_mutex_ = std::shared_ptr<std::mutex> (empty) = {
get() = 0x0
}
}
gef➤ call (void)module.dump(true,true,true)
module __torch__.TinyCNN {
parameters {
}
attributes {
training = True
_is_full_backward_hook = None
network = <__torch__.torch.nn.modules.container.Sequential object at 0x555558a9f430>
}
methods {
method forward {
graph(%self.1 : __torch__.TinyCNN,
%x.1 : Tensor):
```
`_ivalue_.target_` points to an instance of `c10::ivalue::Object`:
```clike
gef➤ ptype module._ivalue_.target_
type = struct c10::ivalue::Object : public c10::intrusive_ptr_target {
private:
...
```
### `c10::ivalue::Object`
According to [ivalue_inl.h](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/core/ivalue_inl.h#L1511), the `Object` structure contains a `type_` and a `slots_` vector.
```clike
struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
// ...
private:
WeakOrStrongTypePtr type_;
std::vector<IValue> slots_;
// ...
```
The memory layout of `Object` is:
```
0x00-0x10 c10::intrusive_ptr_target
0x10-0x50 type_
0x50-0x68 slots_
```
My analysis confirmed that `slots_` acts as a vector storing the children of a node (submodules or weights).
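On x86-64 libstdc++, a `std::vector` header is three pointers (begin, end, capacity), and each `IValue` occupies 0x10 bytes, so the slot count falls out of simple pointer arithmetic. A sketch with synthetic values modeled on my dumps (`parse_slots` is my own helper name):

```python
import struct

IVALUE_SIZE = 0x10  # one IValue is 0x10 bytes (payload + tag)

def parse_slots(raw: bytes, offset: int = 0):
    """Parse a libstdc++ std::vector<IValue> header: begin, end, capacity."""
    begin, end, cap = struct.unpack_from('<3Q', raw, offset)
    count = (end - begin) // IVALUE_SIZE
    return begin, count

# Synthetic header: begin=0x...a7a0, end=0x...a7d0 -> 0x30 / 0x10 = 3 slots.
hdr = struct.pack('<3Q', 0x555558A9A7A0, 0x555558A9A7D0, 0x555558A9A7D0)
begin, count = parse_slots(hdr)
print(count)  # 3
```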
### `c10::IValue`
Each element in `slots_` is a `c10::IValue`, which contains a payload (a union holding either an inline value or a pointer) and a tag defining the type (e.g., `Tensor`, `Int`, or `Object`). `kNumTags` is a static constant, so it occupies no per-instance memory.
```clike
static const int kNumTags;
c10::IValue::Payload payload;
c10::IValue::Tag tag;
```
The full list of `c10::IValue::Tag` values is defined in [ivalue.h](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/core/ivalue.h):
```clike
#define TORCH_FORALL_TAGS(_) \
_(None) \
_(Tensor) \
_(Storage) \
_(Double) \
_(ComplexDouble) \
_(Int) \
_(UInt) \
_(SymInt) \
_(SymFloat) \
_(SymBool) \
_(Bool) \
_(Tuple) \
_(String) \
_(Blob) \
_(GenericList) \
_(GenericDict) \
_(Future) \
_(Await) \
_(Device) \
_(Stream) \
_(Object) \
_(PyObject) \
_(Uninitialized) \
_(Capsule) \
_(RRef) \
_(Quantizer) \
_(Generator) \
_(Enum)
```
The memory layout of `c10::IValue` is
```
0x00-0x08 payload
0x08-0x0c tag
0x0c-0x10 (padding; `kNumTags` is static and not stored per instance)
```
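Assuming the `Tag` enum assigns consecutive values in declaration order starting at zero (the usual behavior for a plain `enum class`), a decoder for a raw `IValue` is straightforward. `parse_ivalue` is my own helper name:

```python
import struct

# Tag names in declaration order from TORCH_FORALL_TAGS; assumed to map
# to consecutive enum values starting at 0.
TAGS = ['None', 'Tensor', 'Storage', 'Double', 'ComplexDouble', 'Int', 'UInt',
        'SymInt', 'SymFloat', 'SymBool', 'Bool', 'Tuple', 'String', 'Blob',
        'GenericList', 'GenericDict', 'Future', 'Await', 'Device', 'Stream',
        'Object', 'PyObject', 'Uninitialized', 'Capsule', 'RRef', 'Quantizer',
        'Generator', 'Enum']

def parse_ivalue(raw: bytes, offset: int = 0):
    payload, tag = struct.unpack_from('<QI', raw, offset)
    name = TAGS[tag] if tag < len(TAGS) else f'unknown({tag})'
    return payload, name

# A synthetic IValue whose payload points at an Object.
blob = struct.pack('<QII', 0x555558A9F430, TAGS.index('Object'), 0)
payload, name = parse_ivalue(blob)
print(name)  # Object
```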
### `at::Tensor`
By comparing different tensor allocations (e.g., $2\times 3$ zeros vs. $4\times 5$ ones), I mapped the binary structure of `at::Tensor`:
```
0x00-0x10 ?
0x10-0x20 Address B
0x20-0x40 ?
0x40-0x48 Dimension of the tensor
0x48-0x48+Dimension*0x08 shapes
...
0xa8-0xa9 ScalarType
```
```
B: 0x00-0x10 ?
B: 0x10-0x20 Address of the weight arrays
```
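Putting the two maps together, weight extraction can be sketched as a pointer walk over a flat memory image. Everything here is a hypothetical helper built on my observed offsets; the ScalarType byte value 6 for float32 is likewise an observation from my dumps, not a documented API:

```python
import struct

def read_u64(mem, base, addr):
    return struct.unpack_from('<Q', mem, addr - base)[0]

def parse_tensor(mem, base, tensor_addr):
    dim = read_u64(mem, base, tensor_addr + 0x40)
    shape = [read_u64(mem, base, tensor_addr + 0x48 + 8 * i) for i in range(dim)]
    stype = mem[tensor_addr - base + 0xa8]       # ScalarType byte (6 = float32)
    b = read_u64(mem, base, tensor_addr + 0x10)  # Address B
    data_addr = read_u64(mem, base, b + 0x10)    # weight array
    n = 1
    for s in shape:
        n *= s
    data = list(struct.unpack_from(f'<{n}f', mem, data_addr - base))
    return shape, stype, data

# Synthetic image holding a 2x3 float32 tensor of ones.
base = 0x1000
mem = bytearray(0x300)
struct.pack_into('<Q', mem, 0x10, 0x1100)     # tensor+0x10 -> B
struct.pack_into('<QQQ', mem, 0x40, 2, 2, 3)  # dim=2, shape=(2, 3)
mem[0xa8] = 6                                 # float32
struct.pack_into('<Q', mem, 0x110, 0x1200)    # B+0x10 -> weight array
struct.pack_into('<6f', mem, 0x200, *([1.0] * 6))
shape, stype, data = parse_tensor(mem, base, 0x1000)
print(shape, stype)  # [2, 3] 6
```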
### What we got now
So far, we can correctly extract the weights from memory. The next step is to recover the information about each layer, such as its name and code.
### Locating Layer Metadata
The layer name (e.g., `Conv2d`) is not stored directly in the object. To identify where the layer names (e.g., `Sequential` vs. `Conv2d`) are stored, I performed a comparative analysis of two different submodule objects from a simple model:
```python
class TinyCNN(nn.Module):
def __init__(self):
super().__init__()
self.network = nn.Sequential(
nn.Conv2d(3, 10, kernel_size=1, bias=False)
)
with torch.no_grad():
self.network[0].weight.zero_()
def forward(self, x):
logits = self.network(x)
return logits
```
The Sequential and the Conv2d layer have distinct internal names: `__torch__.torch.nn.modules.container.Sequential` and `__torch__.torch.nn.modules.conv.Conv2d`. Consequently, any memory offset containing the same value in both objects could be ruled out as the location for the class name.
```
gef➤ x/13gx 0x555558a9f430 // Sequential's object
0x555558a9f430: 0x0000555555583848 0x0000000100000001 // same
0x555558a9f440: 0x00005555588e60b0 0x00005555588e60a0 // same
0x555558a9f450: 0x3f00000000000401 0x0000000000000000
0x555558a9f460: 0x0000555558a9ae30 0x0000555558a9ae00
0x555558a9f470: 0x0000555558a798c0 0x0000555558a9d3e0
0x555558a9f480: 0x0000555558a9a7a0 0x0000555558a9a7d0 // slots_
0x555558a9f490: 0x0000555558a9a7d0 // slots_
gef➤ x/13gx 0x555558a39780 // Conv2d's object
0x555558a39780: 0x0000555555583848 0x0000000100000001 // same
0x555558a39790: 0x00005555588e60b0 0x00005555588e60a0 // same
0x555558a397a0: 0x0000555558ac1701 0x000000000000000b
0x555558a397b0: 0x0000555558abe260 0x0000555500000000
0x555558a397c0: 0x0000555558aba180 0x0000555558a94780
0x555558a397d0: 0x0000555558a2dbe0 0x0000555558a2dc40 // slots_
0x555558a397e0: 0x0000555558a2dc40 // slots_
```
My methodology involved manually corrupting specific memory offsets and then invoking the `module.dump()` function to observe if the output was affected. For instance, nulling out the value at offset 0x20 (address `0x555558a397a0`) did not prevent `dump()` from correctly identifying the layer:
```
gef➤ set {long}0x555558a397a0=0
gef➤ x/13gx 0x555558a39780
0x555558a39780: 0x0000555555583848 0x0000000100000001
0x555558a39790: 0x00005555588e60b0 0x00005555588e60a0
0x555558a397a0: 0x0000000000000000 0x000000000000000b
0x555558a397b0: 0x0000555558abe260 0x0000555500000000
0x555558a397c0: 0x0000555558aba180 0x0000555558a94780
0x555558a397d0: 0x0000555558a2dbe0 0x0000555558a2dc40
0x555558a397e0: 0x0000555558a2dc40
gef➤ call (void)module.dump(true,true,true)
module __torch__.TinyCNN {
parameters {
}
attributes {
training = True
...
submodules {
module __torch__.torch.nn.modules.conv.Conv2d {
...
```
After iterating through all offsets, I determined that only the value at offset `0x40` was critical. Modifying this value caused `dump()` to crash or return an error. This indicated that `0x40` is a pointer to a metadata structure.
By following the pointer at `0x40` and inspecting the destination memory, I discovered that the layer name string is stored at a further offset of `0x38` from that destination.
Based on these findings, the internal memory layout of `c10::ivalue::Object` can be mapped as follows:
```
0x00-0x10 c10::intrusive_ptr_target
0x10-0x40 ?
0x40-0x48 -> addr_0, addr_0+0x38 -> name of the layer
0x48-0x50 ?
0x50-0x68 slots_
```
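The name lookup can then be sketched as a double dereference (hypothetical helpers over a flat memory image; whether `addr_0+0x38` holds a bare pointer or an inline `std::string` matters little, since a `std::string`'s first word is its data pointer either way):

```python
import struct

def read_u64(mem, base, addr):
    return struct.unpack_from('<Q', mem, addr - base)[0]

def read_cstr(mem, base, addr):
    off = addr - base
    return bytes(mem[off:mem.index(b'\x00', off)]).decode()

def layer_name(mem, base, obj_addr):
    addr_0 = read_u64(mem, base, obj_addr + 0x40)    # metadata structure
    str_ptr = read_u64(mem, base, addr_0 + 0x38)     # -> name string
    return read_cstr(mem, base, str_ptr)

# Synthetic image: object at 0x2000, metadata at 0x2100, string at 0x2200.
base = 0x2000
mem = bytearray(0x300)
struct.pack_into('<Q', mem, 0x40, 0x2100)
struct.pack_into('<Q', mem, 0x138, 0x2200)
s = b'__torch__.torch.nn.modules.conv.Conv2d\x00'
mem[0x200:0x200 + len(s)] = s
name = layer_name(mem, base, 0x2000)
print(name)  # __torch__.torch.nn.modules.conv.Conv2d
```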
### The last step: finding the code
During testing, I observed Python-like traceback messages in the C++ environment, which suggested that the serialized TorchScript code remains in memory. By dumping the heap with Volatility 3 and running the `strings` command over it, I successfully recovered class definitions and forward-pass logic for the layers.
Traceback message:
```
Traceback of TorchScript, serialized code (most recent call last):
File "code/__torch__/torch/nn/modules/pooling.py", line 15, in forward
input: Tensor) -> Tensor:
_0 = __torch__.torch.nn.functional._max_pool2d
_1 = _0(input, [1234, 1234], [5678, 5678], [0, 0], [1, 1], False, False, )
~~ <--- HERE
```
Class definitions:
```
class Conv2d(Module):
__parameters__ = ["weight", "bias", ]
__buffers__ = []
weight : Tensor
bias : Optional[Tensor]
training : bool
_is_full_backward_hook : NoneType
transposed : bool
_reversed_padding_repeated_twice : List[int]
padding : Final[Tuple[int, int]] = (1, 1)
out_channels : Final[int] = 128
stride : Final[Tuple[int, int]] = (1, 1)
dilation : Final[Tuple[int, int]] = (1, 1)
in_channels : Final[int] = 128
padding_mode : Final[str] = "zeros"
groups : Final[int] = 1
output_padding : Final[Tuple[int, int]] = (0, 0)
kernel_size : Final[Tuple[int, int]] = (3, 3)
def forward(self: __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d,
input: Tensor) -> Tensor:
weight = self.weight
bias = self.bias
_0 = (self)._conv_forward(input, weight, bias, )
return _0
def _conv_forward(self: __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d,
input: Tensor,
weight: Tensor,
bias: Optional[Tensor]) -> Tensor:
_1 = torch.conv2d(input, weight, bias, [1, 1], [1, 1], [1, 1])
return _1
```
I can identify the class definition for each layer by searching for specific patterns, such as `class Conv2d` or the method signature `def forward(self: __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d`. If a class invokes external functions, I can locate their definitions by searching for the string `def {function name}` within the memory dump.
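The search described above can be sketched as follows; `find_class_defs` is a hypothetical helper, not the final plugin's API. It decodes the raw dump (latin-1 is a lossless byte-to-char mapping) and slices out each `class X(Module):` header together with the indented body that follows:

```python
import re

# Matches a TorchScript class header plus every indented (or blank) line
# after it, stopping at the first de-indented line.
CLASS_RE = re.compile(r'^class (\w+)\(Module\):\n(?:(?:[ \t]+\S.*)?\n)*',
                      re.MULTILINE)

def find_class_defs(dump: bytes) -> dict:
    text = dump.decode('latin-1')
    return {m.group(1): m.group(0) for m in CLASS_RE.finditer(text)}

# A tiny synthetic dump: binary junk surrounding one class definition.
sample = b'junk\xff\nclass ReLU(Module):\n  __parameters__ = []\n  training : bool\nmore junk'
defs = find_class_defs(sample)
print(sorted(defs))  # ['ReLU']
```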
By combining these techniques, I can develop a program to systematically extract the model's weights, architecture, and source code directly from a memory image.
## 5. Take a memory image of the system with LiME
> This section can be skipped.
1. Clone the LiME repo

2. Compile it

3. Run the program with GDB

4. Set a breakpoint at `at::Tensor output = module.forward(inputs).toTensor();` and pause there.

5. Print the PID and the address of the `module` object.

PID: 5019
Address: 0x7fffffffdc80
6. Use LiME to take an image of the system
```shell
sudo insmod ./lime-6.14.0-33-generic.ko "path=/home/vboxuser/Desktop/image.lime format=lime"
```

## 6. Use volatility3 to investigate the image
To ensure compatibility with modern kernels, I used Volatility 3.
1. Set up a Python virtual environment and install volatility3
> It seems that volatility3 2.26.2 does not support my system's kernel, while 2.11.0 does.

2. Use [dwarf2json](https://github.com/volatilityfoundation/dwarf2json) to generate a JSON symbol file that helps Volatility parse the image.
Different Linux kernels arrange memory differently, so Volatility 3 needs a symbol table for the specific kernel.
Install dwarf2json

Install the debug symbol package. [reference](https://documentation.ubuntu.com/server/explanation/debugging/debug-symbol-packages/)
```shell
sudo apt install linux-image-$(uname -r)-dbgsym
```

Use dwarf2json
```shell
sudo ./dwarf2json linux --elf /usr/lib/debug/boot/vmlinux-6.14.0-33-generic > ~/cyfi/ml/symbols/isf.json
```

3. Specify the symbols directory when invoking Volatility.
```shell
vol -f ./memdump/image.lime -s ./symbols linux.pslist.PsList
```

Our process appears in the list.

We can use `linux.proc.Maps` to dump the process's memory
```shell
vol -f ./memdump/image.lime -s ./symbols linux.proc.Maps --pid 5019 --dump
```

## 7. Implement a Volatility 3 plugin
I developed a custom plugin consisting of four modules:
1. `ml.py`: The entry point that connects the plugin to Volatility 3.
2. `memreader.py`: A utility layer for reading primitive types and searching for strings within the memory image.
3. `moduleparser.py`: The core logic that traverses the `Object` tree and extracts tensor weights.
4. `pythonparser.py`: A validation tool to identify and extract valid Python code fragments from the heap.
### `ml.py`
To implement a Volatility 3 plugin, the class must inherit from `plugins.PluginInterface`.
```python
class Dump(plugins.PluginInterface):
```
#### Defining Requirements
The first step is to define the plugin's requirements, such as arguments, versions, supported kernels, and CPU architectures, using the `get_requirements` class method. This ensures the environment is compatible before execution.
```python
@classmethod
def get_requirements(cls):
    # Since we're calling the plugin, make sure we have the plugin's requirements
    return [
        # linux intel cpu
        requirements.ModuleRequirement(
            name="kernel",
            description="Linux kernel",
            architectures=["Intel32", "Intel64"],
        ),
        # this plugin uses proc.Maps and pslist.PsList plugins
        requirements.PluginRequirement(
            name="proc", plugin=proc.Maps, version=(1, 0, 0)
        ),
        requirements.PluginRequirement(
            name="pslist", plugin=pslist.PsList, version=(2, 0, 0)
        ),
        # two required arguments for the task
        requirements.ListRequirement(
            name="pid",
            description="PID of the process containing the ML model",
            element_type=int,
            optional=False,
        ),
        requirements.ListRequirement(
            name="target",
            description="Process virtual memory address of the Module object",
            element_type=int,
            optional=False,
        ),
    ]
```
#### `run()` and `_generator()`
The `run()` method serves as the entry point for the plugin logic. In Volatility 3, the convention is for `run()` to call a private generator, `_generator()`, which yields data row by row.
```python
def run(self):
    # get tasks to do
    return renderers.TreeGrid(
        [
            ("column1", str),
            ("column2", int),
            # ...
        ],
        self._generator(tasks),
    )

def _generator(self, tasks):
    # do tasks one by one
    for task in tasks:
        # do task
        yield (0, ("string", 1))
```
The final structure of the plugin looks like:
```python
class Dump(plugins.PluginInterface):
    _required_framework_version = (2, 0, 0)
    _version = (1, 0, 0)

    @classmethod
    def get_requirements(cls):
        return [
            # ...
        ]

    def run(self):
        # get tasks to do
        return renderers.TreeGrid(
            [
                ("column1", str),
                ("column2", int),
                # ...
            ],
            self._generator(tasks),
        )

    def _generator(self, tasks):
        # do tasks one by one
        for task in tasks:
            # do task
            yield (0, ("string", 1))
```
By following this architecture, I can easily integrate custom memory-parsing utilities into Volatility 3.
### `memreader.py`
This module provides low-level memory abstraction utilities. It includes functions for reading primitive data types from specific addresses (e.g., "read address A as an integer" or "read address A as a string") and scanning the process memory space for specific string patterns or signatures.
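A minimal stand-in for this module might look like the following. The real implementation reads through Volatility's layer API rather than raw dump files, so treat this as a sketch of the interface only; all names here are illustrative:

```python
import struct

class MemReader:
    """Reads primitives out of dumped VMA regions keyed by start address."""

    def __init__(self, regions):  # {start_addr: bytes}
        self.regions = regions

    def _locate(self, addr):
        for start, data in self.regions.items():
            if start <= addr < start + len(data):
                return data, addr - start
        raise ValueError(f'address {addr:#x} not mapped')

    def read_u64(self, addr):
        data, off = self._locate(addr)
        return struct.unpack_from('<Q', data, off)[0]

    def read_cstr(self, addr):
        data, off = self._locate(addr)
        return data[off:data.index(b'\x00', off)].decode()

    def find_all(self, needle: bytes):
        for start, data in self.regions.items():
            off = data.find(needle)
            while off != -1:
                yield start + off
                off = data.find(needle, off + 1)

reader = MemReader({0x7F0000000000: struct.pack('<Q', 42) + b'weights\x00'})
print(reader.read_u64(0x7F0000000000))  # 42
print(reader.read_cstr(0x7F0000000008))  # weights
```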
### `moduleparser.py`
This is the core engine of the plugin, implementing the structural findings of the research. It focuses on three primary Torch classes:
* `torch::jit::script::Module`: Acts as the entry point; the parser locates the memory address storing the root Object.
* `at::Tensor`: Handles the extraction of weight sequences, including their dimensions (shape) and data types.
* `c10::ivalue::Object`: The most complex component, as it represents the model hierarchy. It is parsed recursively to extract the layer names and all nested tensors and sub-objects.
### `pythonparser.py`
Since memory space often contains fragmented or incomplete code blocks, this module validates whether a recovered code snippet is syntactically valid and complete. Additionally, it performs static analysis on a piece of code to identify all internal function calls, allowing the plugin to recursively fetch and reconstruct as much of the original source code as possible.
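A plausible sketch of the validation and call-extraction logic, using Python's standard `ast` module (the function names here are illustrative, not the plugin's actual API):

```python
import ast

def called_names(snippet: str):
    """Return the sorted names called by a snippet, or None if it
    does not parse (i.e., the recovered fragment is truncated)."""
    try:
        tree = ast.parse(snippet)
    except SyntaxError:
        return None
    names = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Call):
            func = node.func
            if isinstance(func, ast.Name):
                names.add(func.id)
            elif isinstance(func, ast.Attribute):
                names.add(func.attr)
    return sorted(names)

good = 'def relu(x):\n    return torch.relu(x)\n'
print(called_names(good))             # ['relu']
print(called_names('def broken(x:'))  # None: fragment is incomplete
```

Each name returned can then be fed back into the string search (`def {function name}`) to recursively pull in more source.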
## 8. Result
To execute the plugin, use the `-p` flag in Volatility 3 to specify the plugin path. The tool generates three distinct output files:
* `weights`: Contains the numerical weights and the corresponding layer order.
* `classes`: Contains the recovered class definitions for each layer.
* `code`: Contains the logic for all functions invoked within those classes.
### layers
Actual model
```python
nn.Conv2d(3, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(2, 2), # output: 64 x 16 x 16
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(2, 2), # output: 128 x 8 x 8
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(2, 2), # output: 256 x 4 x 4
nn.Flatten(),
nn.Linear(256*4*4, 1024),
nn.ReLU(),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Linear(512, 10)
```
Detected model
```
```
Actual weights (head)
```
submodules {
module __torch__.torch.nn.modules.conv.Conv2d {
parameters {
weight = (1,1,.,.) =
-0.1128 -0.0541 -0.1705
-0.1345 0.2090 -0.1443
0.1323 -0.1142 -0.1106
(2,1,.,.) =
0.2198 -0.0825 0.0201
0.0165 -0.0673 -0.0595
0.0388 0.0697 -0.0804
(3,1,.,.) =
-0.2115 -0.1524 -0.0191
-0.2011 -0.0481 0.2243
0.2018 0.1832 0.1267
(4,1,.,.) =
0.1567 0.2159 0.1503
-0.0798 0.0739 -0.2365
-0.0417 -0.0939 -0.0403
(5,1,.,.) =
-0.1408 -0.1113 -0.0857
-0.0494 0.1020 0.1397
-0.1863 0.0569 0.0144
(6,1,.,.) =
-0.1130 0.0497 0.1853
0.1164 -0.1289 0.2558
0.1422 -0.1365 -0.0355
(7,1,.,.) =
0.1117 -0.0246 -0.0573
-0.2087 0.2120 -0.0216
-0.0935 0.1135 0.0671
```
### weights
Detected weights (head)
```
array([[[[-1.12812236e-01, -5.41429818e-02, -1.70484200e-01], <
[-1.34493515e-01, 2.08971888e-01, -1.44343957e-01], <
[ 1.32343054e-01, -1.14226587e-01, -1.10565022e-01]], <
[[ 3.60756330e-02, -1.61835909e-01, 1.69146992e-02],
[ 8.53613168e-02, 1.08232848e-01, -2.01811772e-02],
[-5.32550551e-02, 1.82751402e-01, 9.52708945e-02]],
[[ 5.08270860e-02, -1.02460124e-01, -5.10557368e-02],
[-7.26676211e-02, 2.72562802e-01, 7.36015141e-02],
[-1.90690652e-01, 2.03315616e-01, -4.98034712e-03]]],
[[[ 2.19827488e-01, -8.25090557e-02, 2.01167762e-02], <
[ 1.65425371e-02, -6.73378482e-02, -5.95466904e-02], <
[ 3.87987942e-02, 6.96573928e-02, -8.03895444e-02]], <
[[ 7.39367679e-02, -1.30136073e-01, -1.28461584e-01],
[-1.04819298e-01, 1.00073097e-02, 2.47885793e-01],
[ 2.88021415e-02, 1.09791435e-01, 9.57169477e-03]],
[[ 5.36225401e-02, -1.17352903e-02, 3.24997976e-02],
[-1.33580580e-01, -1.58388242e-01, 2.30599642e-01],
[-1.23156495e-02, -1.24895319e-01, 2.09244471e-02]]],
[[[-2.11464942e-01, -1.52358100e-01, -1.91372670e-02], <
[-2.01056227e-01, -4.81198095e-02, 2.24323779e-01], <
[ 2.01821521e-01, 1.83233574e-01, 1.26737714e-01]], <
[[-1.87800646e-01, 4.02026623e-02, -1.27506450e-01],
[ 6.96158260e-02, 1.31613061e-01, -6.62343483e-03],
[ 1.11682549e-01, 1.19961284e-01, -1.66979969e-01]],
[[-1.52641624e-01, -9.55462009e-02, 4.66656219e-03],
[-1.78928301e-01, 2.12789237e-01, 2.07052663e-01],
[ 3.24359797e-02, 3.32375504e-02, -1.88215166e-01]]],
[[[ 1.56667590e-01, 2.15925068e-01, 1.50287077e-01], <
[-7.97597021e-02, 7.38974810e-02, -2.36540243e-01], <
[-4.16917540e-02, -9.39330980e-02, -4.03037965e-02]], <
[[ 1.65019497e-01, 2.32077554e-01, -1.42064080e-01],
[-4.93854024e-02, 1.31639289e-02, 1.11834696e-02],
[-1.98959596e-02, 2.84764543e-02, -1.55483663e-01]],
```
Actual weights (tail)
```
Columns 511 to 512 8.4559e-02 -3.2748e-02
-6.9094e-02 -1.3549e-02
-8.2849e-02 1.0276e-02
-5.1112e-02 -1.1264e-02
-1.1754e-01 6.3199e-03
-3.4535e-02 -4.1833e-02
1.5253e-01 -2.6820e-02
-1.0617e-01 1.5463e-02
-3.6679e-02 -3.4552e-02
3.5192e-02 -2.3500e-02
```
Detected weights (tail)
```
array([[ 0.00237897, 0.04379057, -0.00424963, ..., -0.00897786,
0.08455919, -0.03274794],
[ 0.01760593, 0.04332345, 0.00937643, ..., 0.03398303,
-0.06909435, -0.01354914],
[-0.03963586, 0.04825915, 0.00400455, ..., 0.03685457,
-0.08284932, 0.01027638],
...,
[-0.03744305, -0.09314314, 0.01747996, ..., 0.03204723,
-0.10616972, 0.01546291],
^^^^^^^^^^^^^^^^^^^^^^^^^
[-0.03020638, 0.04046746, -0.01228058, ..., -0.01300872,
-0.0366788 , -0.03455176],
^^^^^^^^^^^^^^^^^^^^^^^^^
[-0.00629291, 0.04655466, 0.00656728, ..., -0.01035435,
0.03519187, -0.02350029]], shape=(10, 512), dtype=float32)
^^^^^^^^^^^^^^^^^^^^^^^^
```
### code
Some class definitions
```python
####################
class Conv2d(Module):
__parameters__ = ["weight", "bias", ]
__buffers__ = []
weight : Tensor
bias : Optional[Tensor]
training : bool
_is_full_backward_hook : NoneType
transposed : bool
_reversed_padding_repeated_twice : List[int]
padding : Final[Tuple[int, int]] = (1, 1)
out_channels : Final[int] = 64
stride : Final[Tuple[int, int]] = (1, 1)
dilation : Final[Tuple[int, int]] = (1, 1)
in_channels : Final[int] = 32
padding_mode : Final[str] = "zeros"
groups : Final[int] = 1
output_padding : Final[Tuple[int, int]] = (0, 0)
kernel_size : Final[Tuple[int, int]] = (3, 3)
def forward(self: __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv2d,
input: Tensor) -> Tensor:
weight = self.weight
bias = self.bias
_0 = (self)._conv_forward(input, weight, bias, )
return _0
def _conv_forward(self: __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv2d,
input: Tensor,
weight: Tensor,
bias: Optional[Tensor]) -> Tensor:
_1 = torch.conv2d(input, weight, bias, [1, 1], [1, 1], [1, 1])
return _1
####################
class ReLU(Module):
__parameters__ = []
__buffers__ = []
training : bool
_is_full_backward_hook : NoneType
inplace : Final[bool] = False
def forward(self: __torch__.torch.nn.modules.activation.ReLU,
input: Tensor) -> Tensor:
_0 = __torch__.torch.nn.functional.relu(input, False, )
return _0
####################
class MaxPool2d(Module):
__parameters__ = []
__buffers__ = []
training : bool
_is_full_backward_hook : NoneType
padding : Final[int] = 0
return_indices : Final[bool] = False
ceil_mode : Final[bool] = False
stride : Final[int] = 2
dilation : Final[int] = 1
kernel_size : Final[int] = 2
def forward(self: __torch__.torch.nn.modules.pooling.MaxPool2d,
input: Tensor) -> Tensor:
_0 = __torch__.torch.nn.functional._max_pool2d
_1 = _0(input, [2, 2], [2, 2], [0, 0], [1, 1], False, False, )
return _1
```
Some functions' code
```python
####################
# torch.floor_divide
####################
# __torch__.torch.nn.functional.relu
def relu(input: Tensor,
inplace: bool=False) -> Tensor:
if inplace:
result = torch.relu_(input)
else:
result = torch.relu(input)
return result
####################
# torch.relu_
####################
# __torch__.torch.nn.functional._max_pool2d
def _max_pool2d(input: Tensor,
kernel_size: List[int],
stride: Optional[List[int]]=None,
padding: List[int]=[0, 0],
dilation: List[int]=[1, 1],
ceil_mode: bool=False,
return_indices: bool=False) -> Tensor:
if torch.__is__(stride, None):
stride0 = annotate(List[int], [])
else:
stride0 = unchecked_cast(List[int], stride)
_0 = torch.max_pool2d(input, kernel_size, stride0, padding, dilation, ceil_mode)
return _0
```
Certain function definitions could not be located within the heap segment. I have not yet investigated the cause, but I suspect they are built-in primitives whose implementations are native rather than serialized TorchScript, so there is no source to store.
## 9. Conclusion
This project demonstrates that it is possible to reconstruct a machine learning model's architecture, weights, and logic directly from a memory dump of a C++ application.
## 10. Future work
1. Develop a more robust way of locating code fragments that does not rely on string searching.
2. Implement an automated procedure to compare extracted models against the original `.pt` files.
3. Rebuild an executable model from the extracted results.