```python!
import torch
import numpy as np
from sklearn.metrics import roc_auc_score, confusion_matrix
import pickle
from scipy.io import loadmat
# Evaluation settings used by the script at the bottom of this file.
# Raw string: the path contains backslashes, and "\y" is not a valid escape
# sequence (newer Python versions warn about it in a normal string literal).
# NOTE(review): the checkpoint file name mentions "yelp" while data_name is
# "Amazon" -- confirm this is the intended model for this dataset.
# NOTE(review): backslash separators make this path Windows-specific.
path_saver = r'pytorch_models\2023-09-24 23-02-39\yelp_PCGNN.pkl'
data_name = "Amazon"  # only used in the summary line printed by the script
thres = 0.5  # classification threshold applied to the positive-class probability
top = 200  # number of highest-scored nodes used for the "sorted_*" metrics
batch_size = 1024  # number of nodes scored per model call
def prob2pred(y_prob, thres=0.5):
    """
    Turn probabilities into hard 0/1 predictions using a cut-off
    :param y_prob: numpy array of probability in [0, 1]
    :param thres: cut-off; entries >= thres become 1, all others 0 (default 0.5)
    :returns: int32 array of 0/1 values with the same shape as y_prob
    """
    is_positive = y_prob >= thres
    return np.where(is_positive, 1, 0).astype(np.int32)
def test_pcgnn(test_cases, labels, model, batch_size, thres=0.5, top=6000):
    """
    Test the performance of PC-GNN and its variants

    Nodes are scored batch by batch, thresholded with ``prob2pred`` and
    summarised as a confusion matrix over all nodes. The same metrics, plus
    ROC-AUC, are then computed on the ``top`` nodes with the highest score.

    :param test_cases: a list of testing node
    :param labels: a list of testing node labels, aligned with test_cases
        (expected to be 0/1, since predictions are 0/1)
    :param model: the GNN model; must provide
        ``to_prob(nodes, labels, train_flag=False)`` returning a pair whose
        first item is a tensor of per-node scores
    :param batch_size: number nodes in a batch
    :param thres: classification threshold passed to ``prob2pred``
    :param top: how many of the highest-scored nodes the ``sorted_*``
        metrics are computed on
    :returns: the tuple ``(tn, fp, fn, tp, sorted_tn, sorted_fp, sorted_fn,
        sorted_tp, sorted_auc_gnn)``; ``sorted_auc_gnn`` is NaN when the
        top-``top`` labels contain a single class (ROC-AUC is undefined then)
    """
    gnn_pred_list = []
    gnn_prob_list = []
    # Striding with range() yields exactly the non-empty batches; the former
    # "int(len / batch_size) + 1" count fed the model an empty batch whenever
    # the number of nodes was an exact multiple of batch_size.
    for i_start in range(0, len(test_cases), batch_size):
        i_end = i_start + batch_size
        batch_nodes = test_cases[i_start:i_end]
        batch_label = labels[i_start:i_end]
        gnn_prob, _ = model.to_prob(batch_nodes, batch_label, train_flag=False)
        # Column 1 is used as the score (presumably the positive class --
        # TODO confirm against the model's output layout).
        gnn_prob_arr = gnn_prob.data.cpu().numpy()[:, 1]
        gnn_pred_list.extend(prob2pred(gnn_prob_arr, thres).tolist())
        gnn_prob_list.extend(gnn_prob_arr.tolist())

    # labels=[0, 1] pins the matrix to 2x2 so the four-way unpacking below
    # cannot fail when only one class shows up in the data.
    conf_gnn = confusion_matrix(labels, np.array(gnn_pred_list), labels=[0, 1])
    tn, fp, fn, tp = conf_gnn.ravel()

    # Rank nodes by score, highest first (ties keep their original order),
    # and keep the best `top` of them.
    ranked = sorted(zip(gnn_prob_list, labels), key=lambda pair: pair[0], reverse=True)[:top]
    top_probs = np.array([prob for prob, _ in ranked])
    top_labels = np.array([label for _, label in ranked])

    sorted_conf_gnn = confusion_matrix(top_labels, prob2pred(top_probs, thres), labels=[0, 1])
    sorted_tn, sorted_fp, sorted_fn, sorted_tp = sorted_conf_gnn.ravel()

    # ROC-AUC is a ranking metric, so it is computed from the scores rather
    # than from the thresholded 0/1 predictions.
    if np.unique(top_labels).size < 2:
        sorted_auc_gnn = float("nan")
    else:
        sorted_auc_gnn = roc_auc_score(top_labels, top_probs)

    return tn, fp, fn, tp, sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn
def conf_gmean(conf):
    """
    Geometric mean of the two per-class recalls of a binary confusion matrix
    :param conf: array whose ravel() gives exactly four values, read in the
        order tn, fp, fn, tp
    :returns: sqrt(tp * tn / ((tp + fn) * (tn + fp)))
    """
    true_neg, false_pos, false_neg, true_pos = conf.ravel()
    actual_pos = true_pos + false_neg
    actual_neg = true_neg + false_pos
    return (true_pos * true_neg / (actual_pos * actual_neg)) ** 0.5
def load_data(prefix='data/'):
    """
    Read the Amazon dataset files located under ``prefix``
    :param prefix: string prepended verbatim to every file name, so it must
        end with a path separator
    :returns: ``[homo, relation1, relation2, relation3]`` (the unpickled
        homo / upu / usu / uvu adjacency lists, in that order), the dense
        feature array, and the flattened label array
    """
    mat_contents = loadmat(prefix + 'Amazon.mat')
    labels = mat_contents['label'].flatten()
    feat_data = mat_contents['features'].toarray()

    # Preprocessed adjacency lists, one pickle per graph.
    # NOTE(review): pickle.load runs arbitrary code from the file; only use
    # with trusted data files.
    adjlist_files = (
        'amz_homo_adjlists.pickle',
        'amz_upu_adjlists.pickle',
        'amz_usu_adjlists.pickle',
        'amz_uvu_adjlists.pickle',
    )
    graphs = []
    for file_name in adjlist_files:
        with open(prefix + file_name, 'rb') as handle:
            graphs.append(pickle.load(handle))

    return graphs, feat_data, labels
# ---- Evaluation script: load data, load the saved model, print metrics ----
[homo, relation1, relation2, relation3], feat_data, labels = load_data(prefix='./data/')
# Every node in the dataset is evaluated (no held-out split is applied here).
index = list(range(len(labels)))
print(f'Run on {data_name}, postive/total num: {np.sum(labels)}/{len(labels)}')
print(f"Classification threshold: {thres}")
print(f"Feature dimension: {feat_data.shape[1]}")
# NOTE(review): torch.load unpickles the stored object, so only load
# checkpoints from a trusted source. Recent PyTorch releases may also need
# an explicit weights_only=False to restore a whole pickled model -- confirm
# against the installed version.
gnn_model = torch.load(path_saver)
# Use the module-level settings so the threshold/top values printed here are
# the ones actually applied during evaluation.
tn, fp, fn, tp, sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn = test_pcgnn(
    index, labels, gnn_model, batch_size=batch_size, thres=thres, top=top)
print("------------------------------------------------")
print(f"tp: {tp}", f"fn: {fn}")
print(f"fp: {fp}", f"tn: {tn}")
print(f"Precision: {tp / (tp + fp)}")
print(f"Recall: {tp / (tp + fn)}")
# The header reports the real cut-off instead of a hard-coded number.
print(f"For @{top}-------------------------")
print(f"sorted_tp: {sorted_tp}", f"sorted_fn: {sorted_fn}")
print(f"sorted_fp: {sorted_fp}", f"sorted_tn: {sorted_tn}")
print(f"sorted_auc_gnn: {sorted_auc_gnn}")
print(f"sorted_Precision: {sorted_tp / (sorted_tp + sorted_fp)}")
print(f"sorted_Recall: {sorted_tp / (sorted_tp + sorted_fn)}")
```