# main2.py
```python
import argparse
import yaml
import torch
import time
import numpy as np
from collections import defaultdict, OrderedDict
from src.model_handler import ModelHandler
################################################################################
# Main #
################################################################################
def set_random_seed(seed):
    """Seed NumPy and PyTorch (CPU and every CUDA device) for reproducibility.

    :param seed: integer seed applied to all random number generators
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
def main(config):
    """Run one training/evaluation round described by ``config``.

    Prints the configuration, seeds all RNGs, trains the configured model,
    and reports confusion-matrix counts, accuracy, precision and AUC.

    :param config: dict of experiment settings (must contain a 'seed' key)
    """
    print_config(config)
    set_random_seed(config['seed'])
    model = ModelHandler(config)
    tn, fp, fn, tp, auc_test, sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn = model.train()
    print("Testing results________________________________________________")
    # BUG FIX: the first value used to print `tp` under the "tn" label.
    print("tn: ", tn, "fp: ", fp, "fn: ", fn, "tp: ", tp)
    total = sorted_tn + sorted_fp + sorted_fn + sorted_tp
    print("Accuracy: ", (sorted_tp + sorted_tn) / total)
    # NOTE(review): divides by zero when there are no positive predictions in
    # the sorted top-k counts — confirm upstream guarantees both classes appear.
    print("Precision: ", sorted_tp / (sorted_tp + sorted_fp))
    # print("Recall: ", sorted_tp/(sorted_tp+sorted_fn))
    print("AUC: {}".format(sorted_auc_gnn))
################################################################################
# ArgParse and Helper Functions #
################################################################################
def get_config(config_path="config.yml"):
    """Parse the YAML file at ``config_path`` and return its contents.

    :param config_path: path to the experiment configuration file
    :returns: the parsed configuration (a dict for a mapping-style YAML file)
    """
    with open(config_path, "r") as config_file:
        return yaml.load(config_file, Loader=yaml.FullLoader)
def get_args():
    """Parse command-line options and return them as a plain dict.

    :returns: dict with keys 'config' (required path) and 'multi_run' (bool flag)
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-config', '--config', required=True, type=str, help='path to the config file')
    parser.add_argument('--multi_run', action='store_true', help='flag: multi run')
    return vars(parser.parse_args())
def print_config(config):
    """Pretty-print the configuration as left-padded ``key --> value`` lines."""
    banner = "**************** MODEL CONFIGURATION ****************"
    print(banner)
    for name in sorted(config):
        padded = "{}".format(name) + " " * (24 - len(name))
        print("{} --> {}".format(padded, config[name]))
    print(banner)
def grid(kwargs):
    """Expand every list-valued entry of ``kwargs`` into a mesh of configs.

    Non-list values are treated as fixed and copied into every produced
    configuration; list values are crossed with each other (Cartesian product
    via ``np.meshgrid``). Tuples inside lists are wrapped so ``np.meshgrid``
    does not try to broadcast them.
    """
    class MncDc:
        """Wrapper because np.meshgrid does not always handle tuples properly."""
        def __init__(self, a):
            self.a = a  # tuple!
        def __call__(self):
            return self.a

    def merge_dicts(*dicts):
        """Merge dicts left to right; ``None`` entries are skipped."""
        from functools import reduce
        def merge_two_dicts(x, y):
            z = x.copy()
            z.update(y)
            return z
        return reduce(lambda acc, d: merge_two_dicts(acc, d if d else {}), dicts, {})

    swept = OrderedDict((k, v) for k, v in kwargs.items() if isinstance(v, list))
    for key, values in swept.items():
        swept[key] = [MncDc(e) if isinstance(e, tuple) else e for e in values]
    mesh = np.array(np.meshgrid(*swept.values()), dtype=object).T.reshape(-1, len(swept.values()))
    fixed = {k: v for k, v in kwargs.items() if not isinstance(v, list)}
    return [
        merge_dicts(
            fixed,
            {k: row[i]() if isinstance(row[i], MncDc) else row[i] for i, k in enumerate(swept)},
        )
        for row in mesh
    ]
################################################################################
# Module Command-line Behavior #
################################################################################
if __name__ == '__main__':
    # Entry point: read CLI args, load the YAML config they point at, train.
    cli = get_args()
    main(get_config(cli['config']))
```
# model_handler.py
```python
import time, datetime
import os
import random
import argparse
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from src.utils import test_pcgnn, test_sage, load_data, pos_neg_split, normalize, pick_step
from src.model import PCALayer
from src.layers import InterAgg, IntraAgg
from src.graphsage import *
"""
Training PC-GNN
Paper: Pick and Choose: A GNN-based Imbalanced Learning Approach for Fraud Detection
"""
class ModelHandler(object):
    """Builds the dataset splits and trains/evaluates the configured GNN.

    NOTE(review): this module relies on ``from src.graphsage import *`` to
    bring ``torch``, ``nn`` and ``Variable`` into scope — consider importing
    them explicitly.
    """

    def __init__(self, config):
        """Load data, create train/valid/test splits, and store everything train() needs.

        :param config: dict of experiment settings (seed, data_name, data_dir,
                       train_ratio, test_ratio, model, cuda flags, ...)
        :raises ValueError: if ``config['data_name']`` is not 'yelp' or 'amazon'
        """
        args = argparse.Namespace(**config)
        # load graph, feature, and label
        [homo, relation1, relation2, relation3], feat_data, labels = load_data(args.data_name, prefix=args.data_dir)
        # train_test split
        np.random.seed(args.seed)
        random.seed(args.seed)
        if args.data_name == 'yelp':
            index = list(range(len(labels)))
            idx_train, idx_rest, y_train, y_rest = train_test_split(index, labels, stratify=labels,
                                                                    train_size=args.train_ratio,
                                                                    random_state=2, shuffle=True)
            idx_valid, idx_test, y_valid, y_test = train_test_split(idx_rest, y_rest, stratify=y_rest,
                                                                    test_size=args.test_ratio,
                                                                    random_state=2, shuffle=True)
        elif args.data_name == 'amazon':
            # nodes 0-3304 are unlabeled, so they are excluded from every split
            index = list(range(3305, len(labels)))
            idx_train, idx_rest, y_train, y_rest = train_test_split(index, labels[3305:], stratify=labels[3305:],
                                                                    train_size=args.train_ratio,
                                                                    random_state=2, shuffle=True)
            idx_valid, idx_test, y_valid, y_test = train_test_split(idx_rest, y_rest, stratify=y_rest,
                                                                    test_size=args.test_ratio,
                                                                    random_state=2, shuffle=True)
        else:
            # BUG FIX: an unknown dataset name used to fall through and crash
            # later with UnboundLocalError; fail loudly and early instead.
            raise ValueError("Unknown data_name: {!r} (expected 'yelp' or 'amazon')".format(args.data_name))
        print(f'Run on {args.data_name}, postive/total num: {np.sum(labels)}/{len(labels)}, train num {len(y_train)},'+
              f'valid num {len(y_valid)}, test num {len(y_test)}, test positive num {np.sum(y_test)}')
        print(f"Classification threshold: {args.thres}")
        print(f"Feature dimension: {feat_data.shape[1]}")
        # split pos neg sets for under-sampling
        train_pos, train_neg = pos_neg_split(idx_train, y_train)
        feat_data = normalize(feat_data)
        args.cuda = not args.no_cuda and torch.cuda.is_available()
        os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_id
        # set input graph: the homogeneous graph for the baselines,
        # one adjacency list per relation for the multi-relation models
        if args.model == 'SAGE' or args.model == 'GCN':
            adj_lists = homo
        else:
            adj_lists = [relation1, relation2, relation3]
        print(f'Model: {args.model}, multi-relation aggregator: {args.multi_relation}, emb_size: {args.emb_size}.')
        self.args = args
        self.dataset = {'feat_data': feat_data, 'labels': labels, 'adj_lists': adj_lists, 'homo': homo,
                        'idx_train': idx_train, 'idx_valid': idx_valid, 'idx_test': idx_test,
                        'y_train': y_train, 'y_valid': y_valid, 'y_test': y_test,
                        'train_pos': train_pos, 'train_neg': train_neg}

    def _build_model(self, features, feat_data, adj_lists):
        """Construct the GNN selected by ``args.model``.

        :returns: an un-trained model exposing ``loss`` and ``to_prob``
        :raises ValueError: for an unrecognized ``args.model``
        """
        args = self.args
        if args.model == 'PCGNN':
            intra1 = IntraAgg(features, feat_data.shape[1], args.emb_size, self.dataset['train_pos'], args.rho, cuda=args.cuda)
            intra2 = IntraAgg(features, feat_data.shape[1], args.emb_size, self.dataset['train_pos'], args.rho, cuda=args.cuda)
            intra3 = IntraAgg(features, feat_data.shape[1], args.emb_size, self.dataset['train_pos'], args.rho, cuda=args.cuda)
            inter1 = InterAgg(features, feat_data.shape[1], args.emb_size, self.dataset['train_pos'],
                              adj_lists, [intra1, intra2, intra3], inter=args.multi_relation, cuda=args.cuda)
            return PCALayer(2, inter1, args.alpha)
        if args.model == 'SAGE':
            # the vanilla GraphSAGE model as baseline
            agg_sage = MeanAggregator(features, cuda=args.cuda)
            enc_sage = Encoder(features, feat_data.shape[1], args.emb_size, adj_lists, agg_sage, gcn=False, cuda=args.cuda)
            enc_sage.num_samples = 5
            return GraphSage(2, enc_sage)
        if args.model == 'GCN':
            agg_gcn = GCNAggregator(features, cuda=args.cuda)
            enc_gcn = GCNEncoder(features, feat_data.shape[1], args.emb_size, adj_lists, agg_gcn, gcn=True, cuda=args.cuda)
            return GCN(2, enc_gcn)
        # BUG FIX: an unknown model name used to crash later with NameError.
        raise ValueError("Unknown model: {!r}".format(args.model))

    def train(self):
        """Train the model, checkpoint on best validation precision, then
        restore the best checkpoint and evaluate on the test set.

        :returns: (tn, fp, fn, tp, auc, sorted_tn, sorted_fp, sorted_fn,
                   sorted_tp, sorted_auc) as produced by test_sage/test_pcgnn
        """
        args = self.args
        feat_data, adj_lists = self.dataset['feat_data'], self.dataset['adj_lists']
        idx_train, y_train = self.dataset['idx_train'], self.dataset['y_train']
        idx_valid, y_valid, idx_test, y_test = self.dataset['idx_valid'], self.dataset['y_valid'], self.dataset['idx_test'], self.dataset['y_test']
        # initialize model input: a frozen embedding table holding raw features
        features = nn.Embedding(feat_data.shape[0], feat_data.shape[1])
        features.weight = nn.Parameter(torch.FloatTensor(feat_data), requires_grad=False)
        if args.cuda:
            features.cuda()
        gnn_model = self._build_model(features, feat_data, adj_lists)
        if args.cuda:
            gnn_model.cuda()
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, gnn_model.parameters()),
                                     lr=args.lr, weight_decay=args.weight_decay)
        timestamp = datetime.datetime.fromtimestamp(int(time.time())).strftime('%Y-%m-%d %H-%M-%S')
        dir_saver = args.save_dir + timestamp
        path_saver = os.path.join(dir_saver, '{}_{}.pkl'.format(args.data_name, args.model))
        f1_mac_best, auc_best, ep_best, precisions_best = 0, 0, -1, 0
        # train the model
        for epoch in range(args.num_epochs):
            # under-sample by picking 2x the number of positive nodes each epoch
            sampled_idx_train = pick_step(idx_train, y_train, self.dataset['homo'],
                                          size=len(self.dataset['train_pos']) * 2)
            random.shuffle(sampled_idx_train)
            # BUG FIX: `int(len/bs) + 1` produced a trailing *empty* batch
            # whenever the sample count was an exact multiple of batch_size.
            num_batches = (len(sampled_idx_train) + args.batch_size - 1) // args.batch_size
            total_loss = 0.0
            epoch_time = 0
            # mini-batch training
            for batch in range(num_batches):
                start_time = time.time()
                i_start = batch * args.batch_size
                i_end = min((batch + 1) * args.batch_size, len(sampled_idx_train))
                batch_nodes = sampled_idx_train[i_start:i_end]
                batch_label = self.dataset['labels'][np.array(batch_nodes)]
                optimizer.zero_grad()
                if args.cuda:
                    loss = gnn_model.loss(batch_nodes, Variable(torch.cuda.LongTensor(batch_label)))
                else:
                    loss = gnn_model.loss(batch_nodes, Variable(torch.LongTensor(batch_label)))
                loss.backward()
                optimizer.step()
                epoch_time += time.time() - start_time
                # BUG FIX: the old code did `loss += loss.item()` (mutating the
                # batch loss tensor) and then printed only the last batch's
                # (doubled) loss; accumulate into a separate float instead.
                total_loss += loss.item()
            print(f'Epoch: {epoch}, loss: {total_loss / num_batches}, time: {epoch_time}s')
            # Valid the model for every $valid_epoch$ epoch
            if epoch % args.valid_epochs == 0:
                print("Valid at epoch {}".format(epoch))
                if args.model == 'SAGE' or args.model == 'GCN':
                    tn, fp, fn, tp, auc_gnn, sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn = test_sage(idx_valid, y_valid, gnn_model, args.batch_size, args.thres)
                else:
                    tn, fp, fn, tp, auc_gnn, sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn = test_pcgnn(idx_valid, y_valid, gnn_model, args.batch_size, args.thres)
                precisions_val = tp / (tp + fp)
                # checkpoint whenever validation precision improves
                if precisions_val > precisions_best:
                    ep_best, precisions_best = epoch, precisions_val
                    if not os.path.exists(dir_saver):
                        os.makedirs(dir_saver)
                    print('Saving model ...', args.model)
                    torch.save(gnn_model.state_dict(), path_saver)
        print("Restore model from epoch {}".format(ep_best))
        print("Model path: {}".format(path_saver))
        gnn_model.load_state_dict(torch.load(path_saver))
        if args.model == 'SAGE' or args.model == 'GCN':
            tn, fp, fn, tp, auc_gnn, sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn = test_sage(idx_test, y_test, gnn_model, args.batch_size, args.thres)
        else:
            tn, fp, fn, tp, auc_gnn, sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn = test_pcgnn(idx_test, y_test, gnn_model, args.batch_size, args.thres)
        return tn, fp, fn, tp, auc_gnn, sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn
```
# utils.py
```python
import pickle
import random
import numpy as np
import scipy.sparse as sp
from scipy.io import loadmat
import copy as cp
from sklearn.metrics import f1_score, accuracy_score, recall_score, roc_auc_score, average_precision_score, confusion_matrix
from collections import defaultdict
"""
Utility functions to handle data and evaluate model.
"""
def _load_pickled_adjlists(prefix, names):
    """Load one pickled adjacency list per filename in ``names``."""
    graphs = []
    for name in names:
        # NOTE: these pickles are trusted local artifacts; pickle.load on
        # untrusted input would be unsafe.
        with open(prefix + name, 'rb') as file:
            graphs.append(pickle.load(file))
    return graphs

def load_data(data, prefix='data/'):
    """
    Load graph, feature, and label given dataset name
    :param data: dataset name, 'yelp' or 'amazon'
    :param prefix: directory holding the .mat and adjacency-list pickle files
    :returns: homo and single-relation graphs, feature, label
    :raises ValueError: for any other dataset name (previously this fell
                        through and crashed with UnboundLocalError)
    """
    if data == 'yelp':
        data_file = loadmat(prefix + 'YelpChi.mat')
        pickle_names = ['yelp_homo_adjlists.pickle', 'yelp_rur_adjlists.pickle',
                        'yelp_rtr_adjlists.pickle', 'yelp_rsr_adjlists.pickle']
    elif data == 'amazon':
        data_file = loadmat(prefix + 'Amazon.mat')
        pickle_names = ['amz_homo_adjlists.pickle', 'amz_upu_adjlists.pickle',
                        'amz_usu_adjlists.pickle', 'amz_uvu_adjlists.pickle']
    else:
        raise ValueError("Unknown dataset name: {!r}".format(data))
    labels = data_file['label'].flatten()
    feat_data = data_file['features'].todense().A
    # load the preprocessed adj_lists (fix: the original repeated four
    # `with` blocks each followed by a redundant file.close())
    homo, relation1, relation2, relation3 = _load_pickled_adjlists(prefix, pickle_names)
    return [homo, relation1, relation2, relation3], feat_data, labels
def normalize(mx):
    """
    Row-normalize sparse matrix
    Code from https://github.com/williamleif/graphsage-simple/
    (0.01 is added to every row sum to avoid division by zero)
    """
    row_sums = np.array(mx.sum(1)) + 0.01
    inv = np.power(row_sums, -1).flatten()
    inv[np.isinf(inv)] = 0.
    return sp.diags(inv).dot(mx)
def sparse_to_adjlist(sp_matrix, filename):
    """
    Transfer sparse matrix to adjacency list
    :param sp_matrix: the sparse matrix
    :param filename: the filename of adjlist
    """
    # add a self loop to every node, then record both directions of each edge
    with_loops = sp_matrix + sp.eye(sp_matrix.shape[0])
    adj_lists = defaultdict(set)
    rows, cols = with_loops.nonzero()
    for src, dst in zip(rows, cols):
        adj_lists[src].add(dst)
        adj_lists[dst].add(src)
    with open(filename, 'wb') as out:
        pickle.dump(adj_lists, out)
def pos_neg_split(nodes, labels):
    """
    Find positive and negative nodes given a list of nodes and their labels
    :param nodes: a list of nodes
    :param labels: a list of node labels (1 marks a positive node)
    :returns: the split positive and negative nodes (deep copies of the input)
    """
    node_copy = cp.deepcopy(nodes)
    pos_nodes = [node_copy[i] for i, label in enumerate(labels) if label == 1]
    # negatives are whatever remains after removing each positive by value
    neg_nodes = cp.deepcopy(nodes)
    for positive in pos_nodes:
        neg_nodes.remove(positive)
    return pos_nodes, neg_nodes
def pick_step(idx_train, y_train, adj_list, size):
degree_train = [len(adj_list[node]) for node in idx_train]
lf_train = (y_train.sum()-len(y_train))*y_train + len(y_train)
smp_prob = np.array(degree_train) / lf_train
return random.choices(idx_train, weights=smp_prob, k=size)
def test_sage(test_cases, labels, model, batch_size, thres=0.5):
    """
    Test the performance of GraphSAGE
    :param test_cases: a list of testing node
    :param labels: a list of testing node labels
    :param model: the GNN model
    :param batch_size: number nodes in a batch
    :param thres: binary classification threshold passed to prob2pred
    :returns: (tn, fp, fn, tp, auc) over the whole test set, then
              (tn, fp, fn, tp) over only the 10 highest-probability nodes,
              then the whole-set AUC *again* (see NOTE before the return)
    """
    # NOTE(review): int(len/batch_size)+1 yields one trailing empty batch when
    # len(test_cases) is an exact multiple of batch_size — confirm
    # model.to_prob tolerates an empty node list.
    test_batch_num = int(len(test_cases) / batch_size) + 1
    gnn_pred_list = []
    gnn_prob_list = []
    for iteration in range(test_batch_num):
        i_start = iteration * batch_size
        i_end = min((iteration + 1) * batch_size, len(test_cases))
        batch_nodes = test_cases[i_start:i_end]
        batch_label = labels[i_start:i_end]  # computed but unused in this function
        gnn_prob = model.to_prob(batch_nodes)
        # column 1 = score of the positive (fraud) class for each batch node
        gnn_prob_arr = gnn_prob.data.cpu().numpy()[:, 1]
        gnn_pred = prob2pred(gnn_prob_arr, thres)
        gnn_pred_list.extend(gnn_pred.tolist())
        gnn_prob_list.extend(gnn_prob_arr.tolist())
    auc_gnn = roc_auc_score(labels, np.array(gnn_prob_list))
    # f1_binary_1_gnn = f1_score(labels, np.array(gnn_pred_list), pos_label=1, average='binary')
    # f1_binary_0_gnn = f1_score(labels, np.array(gnn_pred_list), pos_label=0, average='binary')
    # f1_micro_gnn = f1_score(labels, np.array(gnn_pred_list), average='micro')
    # f1_macro_gnn = f1_score(labels, np.array(gnn_pred_list), average='macro')
    conf_gnn = confusion_matrix(labels, np.array(gnn_pred_list))
    # NOTE(review): this unpack assumes a full 2x2 matrix; it fails when the
    # labels or predictions contain only a single class.
    tn, fp, fn, tp = conf_gnn.ravel()
    # gmean_gnn = conf_gmean(conf_gnn)
    # print(f" GNN F1-binary-1: {f1_binary_1_gnn:.4f}\tF1-binary-0: {f1_binary_0_gnn:.4f}"+
    # f"\tF1-macro: {f1_macro_gnn:.4f}\tG-Mean: {gmean_gnn:.4f}\tAUC: {auc_gnn:.4f}")
    # print(f" GNN TP: {tp}\tTN: {tn}\tFN: {fn}\tFP: {fp}")
    # return f1_macro_gnn, f1_binary_1_gnn, f1_binary_0_gnn, auc_gnn, gmean_gnn
    # Rank all nodes by predicted fraud probability and score only the top 10
    # (a precision@10-style check on the most suspicious nodes).
    sorted_gnn_prob_list = list(gnn_prob_list)
    sorted_labels_list = list(labels)
    combined = list(zip(sorted_gnn_prob_list, sorted_labels_list))
    sorted_combined = sorted(combined, key=lambda x: x[0], reverse=True)
    sorted_gnn_prob_list, sorted_labels_list = zip(*sorted_combined)
    print("gnn_prob_list", sorted_gnn_prob_list[:10])
    print("labels", sorted_labels_list[:10])
    # same 2x2 ravel() assumption as above, here on just 10 samples
    sorted_conf_gnn = confusion_matrix(sorted_labels_list[:10], prob2pred(np.array(sorted_gnn_prob_list[:10]), thres))
    sorted_tn, sorted_fp, sorted_fn, sorted_tp = sorted_conf_gnn.ravel()
    # NOTE(review): the line below (which would compute a top-10 AUC) is
    # disabled, so the whole-set auc_gnn is printed and returned in the
    # sorted_auc_gnn slot — callers treating it as a top-10 metric actually
    # receive the global AUC.
    # sorted_auc_gnn = roc_auc_score(sorted_labels_list[:10], prob2pred(np.array(sorted_gnn_prob_list[:10]), thres))
    print("sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn", sorted_tn, sorted_fp, sorted_fn, sorted_tp, auc_gnn)
    return tn, fp, fn, tp, auc_gnn, sorted_tn, sorted_fp, sorted_fn, sorted_tp, auc_gnn
# return tn, fp, fn, tp, auc_gnn
def prob2pred(y_prob, thres=0.5):
    """
    Convert probability to predicted results according to given threshold
    :param y_prob: numpy array of probability in [0, 1]
    :param thres: binary classification threshold, default 0.5
    :returns: int32 0/1 predictions with the same shape as y_prob
              (1 wherever y_prob >= thres)
    """
    return (y_prob >= thres).astype(np.int32)
def test_pcgnn(test_cases, labels, model, batch_size, thres=0.5):
    """
    Test the performance of PC-GNN and its variants
    :param test_cases: a list of testing node
    :param labels: a list of testing node labels
    :param model: the GNN model
    :param batch_size: number nodes in a batch
    :param thres: binary classification threshold, default 0.5
    :returns: (tn, fp, fn, tp, auc) over the whole test set, then
              (tn, fp, fn, tp, auc) computed over only the 10
              highest-probability nodes
    """
    # NOTE(review): int(len/batch_size)+1 yields one trailing empty batch when
    # len(test_cases) is an exact multiple of batch_size — confirm
    # model.to_prob and the sklearn metrics tolerate empty batches.
    test_batch_num = int(len(test_cases) / batch_size) + 1
    # f1_gnn/acc_gnn/recall_gnn are never updated or read below (dead state)
    f1_gnn = 0.0
    acc_gnn = 0.0
    recall_gnn = 0.0
    # running sums of per-batch metrics for the label-aware (Simi) module;
    # only consumed by the commented-out diagnostics further down
    f1_label1 = 0.0
    acc_label1 = 0.00
    recall_label1 = 0.0
    gnn_pred_list = []
    gnn_prob_list = []
    label_list1 = []
    for iteration in range(test_batch_num):
        i_start = iteration * batch_size
        i_end = min((iteration + 1) * batch_size, len(test_cases))
        batch_nodes = test_cases[i_start:i_end]
        batch_label = labels[i_start:i_end]
        # label_prob1 holds the label-aware (Simi) module scores
        gnn_prob, label_prob1 = model.to_prob(batch_nodes, batch_label, train_flag=False)
        gnn_prob_arr = gnn_prob.data.cpu().numpy()[:, 1]  # positive-class score
        gnn_pred = prob2pred(gnn_prob_arr, thres)
        f1_label1 += f1_score(batch_label, label_prob1.data.cpu().numpy().argmax(axis=1), average="macro")
        acc_label1 += accuracy_score(batch_label, label_prob1.data.cpu().numpy().argmax(axis=1))
        recall_label1 += recall_score(batch_label, label_prob1.data.cpu().numpy().argmax(axis=1), average="macro")
        gnn_pred_list.extend(gnn_pred.tolist())
        gnn_prob_list.extend(gnn_prob_arr.tolist())
        label_list1.extend(label_prob1.data.cpu().numpy()[:, 1].tolist())
    auc_gnn = roc_auc_score(labels, np.array(gnn_prob_list))
    # ap_gnn = average_precision_score(labels, np.array(gnn_prob_list))
    # auc_label1 = roc_auc_score(labels, np.array(label_list1))
    # ap_label1 = average_precision_score(labels, np.array(label_list1))
    # f1_binary_1_gnn = f1_score(labels, np.array(gnn_pred_list), pos_label=1, average='binary')
    # f1_binary_0_gnn = f1_score(labels, np.array(gnn_pred_list), pos_label=0, average='binary')
    # f1_micro_gnn = f1_score(labels, np.array(gnn_pred_list), average='micro')
    # f1_macro_gnn = f1_score(labels, np.array(gnn_pred_list), average='macro')
    conf_gnn = confusion_matrix(labels, np.array(gnn_pred_list))
    # NOTE(review): this unpack assumes a full 2x2 matrix; it fails when the
    # labels or predictions contain only a single class.
    tn, fp, fn, tp = conf_gnn.ravel()
    # gmean_gnn = conf_gmean(conf_gnn)
    # print(f" GNN F1-binary-1: {f1_binary_1_gnn:.4f}\tF1-binary-0: {f1_binary_0_gnn:.4f}"+
    # f"\tF1-macro: {f1_macro_gnn:.4f}\tG-Mean: {gmean_gnn:.4f}\tAUC: {auc_gnn:.4f}")
    # print(f" GNN TP: {tp}\tTN: {tn}\tFN: {fn}\tFP: {fp}")
    # print(f"Label1 F1: {f1_label1 / test_batch_num:.4f}\tAccuracy: {acc_label1 / test_batch_num:.4f}"+
    # f"\tRecall: {recall_label1 / test_batch_num:.4f}\tAUC: {auc_label1:.4f}\tAP: {ap_label1:.4f}")
    # return f1_macro_gnn, f1_binary_1_gnn, f1_binary_0_gnn, auc_gnn, gmean_gnn
    # Rank all nodes by predicted fraud probability and score only the top 10
    # (a precision@10-style check on the most suspicious nodes).
    sorted_gnn_prob_list = list(gnn_prob_list)
    sorted_labels_list = list(labels)
    combined = list(zip(sorted_gnn_prob_list, sorted_labels_list))
    sorted_combined = sorted(combined, key=lambda x: x[0], reverse=True)
    sorted_gnn_prob_list, sorted_labels_list = zip(*sorted_combined)
    print("gnn_prob_list", sorted_gnn_prob_list[:10])
    print("labels", sorted_labels_list[:10])
    # same 2x2 ravel() assumption as above, here on just 10 samples
    sorted_conf_gnn = confusion_matrix(sorted_labels_list[:10], prob2pred(np.array(sorted_gnn_prob_list[:10]), thres))
    sorted_tn, sorted_fp, sorted_fn, sorted_tp = sorted_conf_gnn.ravel()
    # NOTE(review): AUC over 0/1 *predictions* (not probabilities) of only 10
    # samples; roc_auc_score raises ValueError if those labels are all one class.
    sorted_auc_gnn = roc_auc_score(sorted_labels_list[:10], prob2pred(np.array(sorted_gnn_prob_list[:10]), thres))
    print("sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn", sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn)
    return tn, fp, fn, tp, auc_gnn, sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn
def conf_gmean(conf):
    """Geometric mean of sensitivity and specificity from a 2x2 confusion
    matrix (ravel order: tn, fp, fn, tp)."""
    true_neg, false_pos, false_neg, true_pos = conf.ravel()
    return (true_pos * true_neg / ((true_pos + false_neg) * (true_neg + false_pos))) ** 0.5
```
# model.py
```python
import torch
import torch.nn as nn
from torch.nn import init
"""
PC-GNN Model
Paper: Pick and Choose: A GNN-based Imbalanced Learning Approach for Fraud Detection
Modified from https://github.com/YingtongDou/CARE-GNN
"""
class PCALayer(nn.Module):
"""
One Pick-Choose-Aggregate layer
"""
def __init__(self, num_classes, inter1, lambda_1):
"""
Initialize the PC-GNN model
:param num_classes: number of classes (2 in our paper)
:param inter1: the inter-relation aggregator that output the final embedding
"""
super(PCALayer, self).__init__()
self.inter1 = inter1
self.xent = nn.CrossEntropyLoss()
# the parameter to transform the final embedding
self.weight = nn.Parameter(torch.FloatTensor(num_classes, inter1.embed_dim))
init.xavier_uniform_(self.weight)
self.lambda_1 = lambda_1
self.epsilon = 0.1
def forward(self, nodes, labels, train_flag=True):
embeds1, label_scores = self.inter1(nodes, labels, train_flag)
scores = self.weight.mm(embeds1)
return scores.t(), label_scores
def to_prob(self, nodes, labels, train_flag=True):
gnn_logits, label_logits = self.forward(nodes, labels, train_flag)
gnn_scores = torch.sigmoid(gnn_logits)
label_scores = torch.sigmoid(label_logits)
return gnn_scores, label_scores
# def loss(self, nodes, labels, train_flag=True):
# gnn_scores, label_scores = self.forward(nodes, labels, train_flag)
# # Simi loss, Eq. (7) in the paper
# label_loss = self.xent(label_scores, labels.squeeze())
# # GNN loss, Eq. (10) in the paper
# gnn_loss = self.xent(gnn_scores, labels.squeeze())
# # the loss function of PC-GNN, Eq. (11) in the paper
# final_loss = gnn_loss + self.lambda_1 * label_loss
# return final_loss
def focal_loss(self, input_probs, targets, alpha=0.5, gamma=2.0):
pt = torch.exp(-self.xent(input_probs, targets)) # 預測概率的指數部分
focal_term = (1 - pt) ** gamma # Focal Loss 中的調整項
alpha_term = alpha * (1 - pt) # 平衡因子與 Focal Loss 的結合
focal_loss = alpha_term * focal_term * self.xent(input_probs, targets)
return focal_loss.mean() # 平均 Focal Loss
def loss(self, nodes, labels, train_flag=True):
gnn_scores, label_scores = self.forward(nodes, labels, train_flag)
# Simi loss, Eq. (7) in the paper
label_loss = self.focal_loss(label_scores, labels.squeeze())
# GNN loss, Eq. (10) in the paper
gnn_loss = self.focal_loss(gnn_scores, labels.squeeze())
# the loss function of PC-GNN, Eq. (11) in the paper
final_loss = gnn_loss + self.lambda_1 * label_loss
return final_loss
```