```python! import torch import numpy as np from sklearn.metrics import roc_auc_score, confusion_matrix import pickle from scipy.io import loadmat path_saver = 'pytorch_models\\2023-09-24 23-02-39\yelp_PCGNN.pkl' data_name = "Amazon" thres = 0.5 top = 200 batch_size = 1024 def prob2pred(y_prob, thres=0.5): """ Convert probability to predicted results according to given threshold :param y_prob: numpy array of probability in [0, 1] :param thres: binary classification threshold, default 0.5 :returns: the predicted result with the same shape as y_prob """ y_pred = np.zeros_like(y_prob, dtype=np.int32) y_pred[y_prob >= thres] = 1 y_pred[y_prob < thres] = 0 return y_pred def test_pcgnn(test_cases, labels, model, batch_size, thres=0.5, top = 6000): """ Test the performance of PC-GNN and its variants :param test_cases: a list of testing node :param labels: a list of testing node labels :param model: the GNN model :param batch_size: number nodes in a batch :returns: the AUC and Recall of GNN and Simi modules """ test_batch_num = int(len(test_cases) / batch_size) + 1 gnn_pred_list = [] gnn_prob_list = [] label_list1 = [] for iteration in range(test_batch_num): i_start = iteration * batch_size i_end = min((iteration + 1) * batch_size, len(test_cases)) batch_nodes = test_cases[i_start:i_end] batch_label = labels[i_start:i_end] gnn_prob, label_prob1 = model.to_prob(batch_nodes, batch_label, train_flag=False) gnn_prob_arr = gnn_prob.data.cpu().numpy()[:, 1] gnn_pred = prob2pred(gnn_prob_arr, thres) gnn_pred_list.extend(gnn_pred.tolist()) gnn_prob_list.extend(gnn_prob_arr.tolist()) label_list1.extend(label_prob1.data.cpu().numpy()[:, 1].tolist()) conf_gnn = confusion_matrix(labels, np.array(gnn_pred_list)) tn, fp, fn, tp = conf_gnn.ravel() # gmean_gnn = conf_gmean(conf_gnn) sorted_gnn_prob_list = list(gnn_prob_list) sorted_labels_list = list(labels) combined = list(zip(sorted_gnn_prob_list, sorted_labels_list)) sorted_combined = sorted(combined, key=lambda x: x[0], reverse=True) sorted_gnn_prob_list, sorted_labels_list = zip(*sorted_combined) sorted_conf_gnn = confusion_matrix(sorted_labels_list[:top], prob2pred(np.array(sorted_gnn_prob_list[:top]), thres)) sorted_tn, sorted_fp, sorted_fn, sorted_tp = sorted_conf_gnn.ravel() sorted_auc_gnn = roc_auc_score(sorted_labels_list[:top], prob2pred(np.array(sorted_gnn_prob_list[:top]), thres)) return tn, fp, fn, tp, sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn def conf_gmean(conf): tn, fp, fn, tp = conf.ravel() return (tp*tn/((tp+fn)*(tn+fp)))**0.5 def load_data( prefix='data/'): """ Load graph, feature, and label given dataset name :returns: home and single-relation graphs, feature, label """ data_file = loadmat(prefix + 'Amazon.mat') labels = data_file['label'].flatten() feat_data = data_file['features'].todense().A # load the preprocessed adj_lists with open(prefix + 'amz_homo_adjlists.pickle', 'rb') as file: homo = pickle.load(file) file.close() with open(prefix + 'amz_upu_adjlists.pickle', 'rb') as file: relation1 = pickle.load(file) file.close() with open(prefix + 'amz_usu_adjlists.pickle', 'rb') as file: relation2 = pickle.load(file) file.close() with open(prefix + 'amz_uvu_adjlists.pickle', 'rb') as file: relation3 = pickle.load(file) return [homo, relation1, relation2, relation3], feat_data, labels [homo, relation1, relation2, relation3], feat_data, labels = load_data(prefix='./data/') index = list(range(len(labels))) print(f'Run on {data_name}, postive/total num: {np.sum(labels)}/{len(labels)}') print(f"Classification threshold: {thres}") print(f"Feature dimension: {feat_data.shape[1]}") gnn_model = torch.load(path_saver) tn, fp, fn, tp, sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn = test_pcgnn(index, labels, gnn_model, batch_size = 1024, thres = 0.5, top = top) print("------------------------------------------------") print("tp: {}".format(tp),"fn: {}".format(fn)) print("fp: {}".format(fp),"tn: {}".format(tn)) print("Precision: {}".format(tp/(tp+fp))) print("Recall: {}".format( tp/(tp+fn))) print("For @2000-------------------------") print("sorted_tp: {}".format(sorted_tp),"sorted_fn: {}".format(sorted_fn)) print("sorted_fp: {}".format(sorted_fp),"sorted_tn: {}".format(sorted_tn)) print("sorted_auc_gnn: {}".format(sorted_auc_gnn)) print("sorted_Precision: {}".format(sorted_tp/(sorted_tp+sorted_fp))) print("sorted_Recall: {}".format( sorted_tp/(sorted_tp+sorted_fn))) ```