# main2.py ```python! import argparse import yaml import torch import time import numpy as np from collections import defaultdict, OrderedDict from src.model_handler import ModelHandler ################################################################################ # Main # ################################################################################ def set_random_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) def main(config): print_config(config) set_random_seed(config['seed']) model = ModelHandler(config) tn, fp, fn, tp, auc_test, sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn = model.train() print("Testing results________________________________________________") print("tn: ",tp, "fp: ",fp, "fn: ",fn , "tp: ",tp) print("Accuracy: ", (sorted_tp+sorted_tn)/(sorted_tn + sorted_fp + sorted_fn + sorted_tp) ) print("Precision: ", sorted_tp/(sorted_tp+sorted_fp)) # print("Recall: ", sorted_tp/(sorted_tp+sorted_fn)) print("AUC: {}".format(sorted_auc_gnn)) ################################################################################ # ArgParse and Helper Functions # ################################################################################ def get_config(config_path="config.yml"): with open(config_path, "r") as setting: config = yaml.load(setting, Loader=yaml.FullLoader) return config def get_args(): parser = argparse.ArgumentParser() parser.add_argument('-config', '--config', required=True, type=str, help='path to the config file') parser.add_argument('--multi_run', action='store_true', help='flag: multi run') args = vars(parser.parse_args()) return args def print_config(config): print("**************** MODEL CONFIGURATION ****************") for key in sorted(config.keys()): val = config[key] keystr = "{}".format(key) + (" " * (24 - len(key))) print("{} --> {}".format(keystr, val)) print("**************** MODEL CONFIGURATION ****************") def grid(kwargs): """Builds a mesh grid with given keyword arguments for this Config class. If the value is not a list, then it is considered fixed""" class MncDc: """This is because np.meshgrid does not always work properly...""" def __init__(self, a): self.a = a # tuple! def __call__(self): return self.a def merge_dicts(*dicts): """ Merges dictionaries recursively. Accepts also `None` and returns always a (possibly empty) dictionary """ from functools import reduce def merge_two_dicts(x, y): z = x.copy() # start with x's keys and values z.update(y) # modifies z with y's keys and values & returns None return z return reduce(lambda a, nd: merge_two_dicts(a, nd if nd else {}), dicts, {}) sin = OrderedDict({k: v for k, v in kwargs.items() if isinstance(v, list)}) for k, v in sin.items(): copy_v = [] for e in v: copy_v.append(MncDc(e) if isinstance(e, tuple) else e) sin[k] = copy_v grd = np.array(np.meshgrid(*sin.values()), dtype=object).T.reshape(-1, len(sin.values())) return [merge_dicts( {k: v for k, v in kwargs.items() if not isinstance(v, list)}, {k: vv[i]() if isinstance(vv[i], MncDc) else vv[i] for i, k in enumerate(sin)} ) for vv in grd] ################################################################################ # Module Command-line Behavior # ################################################################################ if __name__ == '__main__': cfg = get_args() config = get_config(cfg['config']) main(config) ``` # model_handler.py ```python! import time, datetime import os import random import argparse import numpy as np from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler from src.utils import test_pcgnn, test_sage, load_data, pos_neg_split, normalize, pick_step from src.model import PCALayer from src.layers import InterAgg, IntraAgg from src.graphsage import * """ Training PC-GNN Paper: Pick and Choose: A GNN-based Imbalanced Learning Approach for Fraud Detection """ class ModelHandler(object): def __init__(self, config): args = argparse.Namespace(**config) # load graph, feature, and label [homo, relation1, relation2, relation3], feat_data, labels = load_data(args.data_name, prefix=args.data_dir) # train_test split np.random.seed(args.seed) random.seed(args.seed) if args.data_name == 'yelp': index = list(range(len(labels))) idx_train, idx_rest, y_train, y_rest = train_test_split(index, labels, stratify=labels, train_size=args.train_ratio, random_state=2, shuffle=True) idx_valid, idx_test, y_valid, y_test = train_test_split(idx_rest, y_rest, stratify=y_rest, test_size=args.test_ratio, random_state=2, shuffle=True) elif args.data_name == 'amazon': # amazon # 0-3304 are unlabeled nodes index = list(range(3305, len(labels))) idx_train, idx_rest, y_train, y_rest = train_test_split(index, labels[3305:], stratify=labels[3305:], train_size=args.train_ratio, random_state=2, shuffle=True) idx_valid, idx_test, y_valid, y_test = train_test_split(idx_rest, y_rest, stratify=y_rest, test_size=args.test_ratio, random_state=2, shuffle=True) print(f'Run on {args.data_name}, postive/total num: {np.sum(labels)}/{len(labels)}, train num {len(y_train)},'+ f'valid num {len(y_valid)}, test num {len(y_test)}, test positive num {np.sum(y_test)}') print(f"Classification threshold: {args.thres}") print(f"Feature dimension: {feat_data.shape[1]}") # split pos neg sets for under-sampling train_pos, train_neg = pos_neg_split(idx_train, y_train) # if args.data == 'amazon': feat_data = normalize(feat_data) # train_feats = feat_data[np.array(idx_train)] # scaler = StandardScaler() # scaler.fit(train_feats) # feat_data = scaler.transform(feat_data) args.cuda = not args.no_cuda and torch.cuda.is_available() os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_id # set input graph if args.model == 'SAGE' or args.model == 'GCN': adj_lists = homo else: adj_lists = [relation1, relation2, relation3] print(f'Model: {args.model}, multi-relation aggregator: {args.multi_relation}, emb_size: {args.emb_size}.') self.args = args self.dataset = {'feat_data': feat_data, 'labels': labels, 'adj_lists': adj_lists, 'homo': homo, 'idx_train': idx_train, 'idx_valid': idx_valid, 'idx_test': idx_test, 'y_train': y_train, 'y_valid': y_valid, 'y_test': y_test, 'train_pos': train_pos, 'train_neg': train_neg} def train(self): args = self.args feat_data, adj_lists = self.dataset['feat_data'], self.dataset['adj_lists'] idx_train, y_train = self.dataset['idx_train'], self.dataset['y_train'] idx_valid, y_valid, idx_test, y_test = self.dataset['idx_valid'], self.dataset['y_valid'], self.dataset['idx_test'], self.dataset['y_test'] # initialize model input features = nn.Embedding(feat_data.shape[0], feat_data.shape[1]) features.weight = nn.Parameter(torch.FloatTensor(feat_data), requires_grad=False) if args.cuda: features.cuda() # build one-layer models if args.model == 'PCGNN': intra1 = IntraAgg(features, feat_data.shape[1], args.emb_size, self.dataset['train_pos'], args.rho, cuda=args.cuda) intra2 = IntraAgg(features, feat_data.shape[1], args.emb_size, self.dataset['train_pos'], args.rho, cuda=args.cuda) intra3 = IntraAgg(features, feat_data.shape[1], args.emb_size, self.dataset['train_pos'], args.rho, cuda=args.cuda) inter1 = InterAgg(features, feat_data.shape[1], args.emb_size, self.dataset['train_pos'], adj_lists, [intra1, intra2, intra3], inter=args.multi_relation, cuda=args.cuda) elif args.model == 'SAGE': agg_sage = MeanAggregator(features, cuda=args.cuda) enc_sage = Encoder(features, feat_data.shape[1], args.emb_size, adj_lists, agg_sage, gcn=False, cuda=args.cuda) elif args.model == 'GCN': agg_gcn = GCNAggregator(features, cuda=args.cuda) enc_gcn = GCNEncoder(features, feat_data.shape[1], args.emb_size, adj_lists, agg_gcn, gcn=True, cuda=args.cuda) if args.model == 'PCGNN': gnn_model = PCALayer(2, inter1, args.alpha) elif args.model == 'SAGE': # the vanilla GraphSAGE model as baseline enc_sage.num_samples = 5 gnn_model = GraphSage(2, enc_sage) elif args.model == 'GCN': gnn_model = GCN(2, enc_gcn) if args.cuda: gnn_model.cuda() optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, gnn_model.parameters()), lr=args.lr, weight_decay=args.weight_decay) timestamp = time.time() timestamp = datetime.datetime.fromtimestamp(int(timestamp)).strftime('%Y-%m-%d %H-%M-%S') dir_saver = args.save_dir+timestamp path_saver = os.path.join(dir_saver, '{}_{}.pkl'.format(args.data_name, args.model)) f1_mac_best, auc_best, ep_best, precisions_best = 0, 0, -1, 0 # train the model for epoch in range(args.num_epochs): sampled_idx_train = pick_step(idx_train, y_train, self.dataset['homo'], size=len(self.dataset['train_pos'])*2) random.shuffle(sampled_idx_train) num_batches = int(len(sampled_idx_train) / args.batch_size) + 1 loss = 0.0 epoch_time = 0 # mini-batch training for batch in range(num_batches): start_time = time.time() i_start = batch * args.batch_size i_end = min((batch + 1) * args.batch_size, len(sampled_idx_train)) batch_nodes = sampled_idx_train[i_start:i_end] batch_label = self.dataset['labels'][np.array(batch_nodes)] optimizer.zero_grad() if args.cuda: loss = gnn_model.loss(batch_nodes, Variable(torch.cuda.LongTensor(batch_label))) else: loss = gnn_model.loss(batch_nodes, Variable(torch.LongTensor(batch_label))) loss.backward() optimizer.step() end_time = time.time() epoch_time += end_time - start_time loss += loss.item() print(f'Epoch: {epoch}, loss: {loss.item() / num_batches}, time: {epoch_time}s') # Valid the model for every $valid_epoch$ epoch if epoch % args.valid_epochs == 0: if args.model == 'SAGE' or args.model == 'GCN': print("Valid at epoch {}".format(epoch)) # f1_mac_val, f1_1_val, f1_0_val, auc_val, gmean_val = test_sage(idx_valid, y_valid, gnn_model, args.batch_size, args.thres) tn, fp, fn, tp, auc_gnn, sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn = test_sage(idx_valid, y_valid, gnn_model, args.batch_size, args.thres) precisions_val = tp/(tp+fp) if precisions_val > precisions_best: ep_best, precisions_best = epoch, precisions_val if not os.path.exists(dir_saver): os.makedirs(dir_saver) print('Saving model ...',args.model) torch.save(gnn_model.state_dict(), path_saver) else: print("Valid at epoch {}".format(epoch)) tn, fp, fn, tp, auc_gnn, sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn = test_pcgnn(idx_valid, y_valid, gnn_model, args.batch_size, args.thres) precisions_val = tp/(tp+fp) if precisions_val > precisions_best: ep_best, precisions_best = epoch, precisions_val if not os.path.exists(dir_saver): os.makedirs(dir_saver) print(' Saving model ...', args.model) torch.save(gnn_model.state_dict(), path_saver) print("Restore model from epoch {}".format(ep_best)) print("Model path: {}".format(path_saver)) gnn_model.load_state_dict(torch.load(path_saver)) if args.model == 'SAGE' or args.model == 'GCN': tn, fp, fn, tp, auc_gnn, sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn = test_sage(idx_test, y_test, gnn_model, args.batch_size, args.thres) else: tn, fp, fn, tp, auc_gnn, sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn = test_pcgnn(idx_test, y_test, gnn_model, args.batch_size, args.thres) return tn, fp, fn, tp, auc_gnn, sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn ``` # utils.py ```python! import pickle import random import numpy as np import scipy.sparse as sp from scipy.io import loadmat import copy as cp from sklearn.metrics import f1_score, accuracy_score, recall_score, roc_auc_score, average_precision_score, confusion_matrix from collections import defaultdict """ Utility functions to handle data and evaluate model. """ def load_data(data, prefix='data/'): """ Load graph, feature, and label given dataset name :returns: home and single-relation graphs, feature, label """ if data == 'yelp': data_file = loadmat(prefix + 'YelpChi.mat') labels = data_file['label'].flatten() feat_data = data_file['features'].todense().A # load the preprocessed adj_lists with open(prefix + 'yelp_homo_adjlists.pickle', 'rb') as file: homo = pickle.load(file) file.close() with open(prefix + 'yelp_rur_adjlists.pickle', 'rb') as file: relation1 = pickle.load(file) file.close() with open(prefix + 'yelp_rtr_adjlists.pickle', 'rb') as file: relation2 = pickle.load(file) file.close() with open(prefix + 'yelp_rsr_adjlists.pickle', 'rb') as file: relation3 = pickle.load(file) file.close() elif data == 'amazon': data_file = loadmat(prefix + 'Amazon.mat') labels = data_file['label'].flatten() feat_data = data_file['features'].todense().A # load the preprocessed adj_lists with open(prefix + 'amz_homo_adjlists.pickle', 'rb') as file: homo = pickle.load(file) file.close() with open(prefix + 'amz_upu_adjlists.pickle', 'rb') as file: relation1 = pickle.load(file) file.close() with open(prefix + 'amz_usu_adjlists.pickle', 'rb') as file: relation2 = pickle.load(file) file.close() with open(prefix + 'amz_uvu_adjlists.pickle', 'rb') as file: relation3 = pickle.load(file) return [homo, relation1, relation2, relation3], feat_data, labels def normalize(mx): """ Row-normalize sparse matrix Code from https://github.com/williamleif/graphsage-simple/ """ rowsum = np.array(mx.sum(1)) + 0.01 r_inv = np.power(rowsum, -1).flatten() r_inv[np.isinf(r_inv)] = 0. r_mat_inv = sp.diags(r_inv) mx = r_mat_inv.dot(mx) return mx def sparse_to_adjlist(sp_matrix, filename): """ Transfer sparse matrix to adjacency list :param sp_matrix: the sparse matrix :param filename: the filename of adjlist """ # add self loop homo_adj = sp_matrix + sp.eye(sp_matrix.shape[0]) # create adj_list adj_lists = defaultdict(set) edges = homo_adj.nonzero() for index, node in enumerate(edges[0]): adj_lists[node].add(edges[1][index]) adj_lists[edges[1][index]].add(node) with open(filename, 'wb') as file: pickle.dump(adj_lists, file) file.close() def pos_neg_split(nodes, labels): """ Find positive and negative nodes given a list of nodes and their labels :param nodes: a list of nodes :param labels: a list of node labels :returns: the spited positive and negative nodes """ pos_nodes = [] neg_nodes = cp.deepcopy(nodes) aux_nodes = cp.deepcopy(nodes) for idx, label in enumerate(labels): if label == 1: pos_nodes.append(aux_nodes[idx]) neg_nodes.remove(aux_nodes[idx]) return pos_nodes, neg_nodes def pick_step(idx_train, y_train, adj_list, size): degree_train = [len(adj_list[node]) for node in idx_train] lf_train = (y_train.sum()-len(y_train))*y_train + len(y_train) smp_prob = np.array(degree_train) / lf_train return random.choices(idx_train, weights=smp_prob, k=size) def test_sage(test_cases, labels, model, batch_size, thres=0.5): """ Test the performance of GraphSAGE :param test_cases: a list of testing node :param labels: a list of testing node labels :param model: the GNN model :param batch_size: number nodes in a batch """ test_batch_num = int(len(test_cases) / batch_size) + 1 gnn_pred_list = [] gnn_prob_list = [] for iteration in range(test_batch_num): i_start = iteration * batch_size i_end = min((iteration + 1) * batch_size, len(test_cases)) batch_nodes = test_cases[i_start:i_end] batch_label = labels[i_start:i_end] gnn_prob = model.to_prob(batch_nodes) gnn_prob_arr = gnn_prob.data.cpu().numpy()[:, 1] gnn_pred = prob2pred(gnn_prob_arr, thres) gnn_pred_list.extend(gnn_pred.tolist()) gnn_prob_list.extend(gnn_prob_arr.tolist()) auc_gnn = roc_auc_score(labels, np.array(gnn_prob_list)) # f1_binary_1_gnn = f1_score(labels, np.array(gnn_pred_list), pos_label=1, average='binary') # f1_binary_0_gnn = f1_score(labels, np.array(gnn_pred_list), pos_label=0, average='binary') # f1_micro_gnn = f1_score(labels, np.array(gnn_pred_list), average='micro') # f1_macro_gnn = f1_score(labels, np.array(gnn_pred_list), average='macro') conf_gnn = confusion_matrix(labels, np.array(gnn_pred_list)) tn, fp, fn, tp = conf_gnn.ravel() # gmean_gnn = conf_gmean(conf_gnn) # print(f" GNN F1-binary-1: {f1_binary_1_gnn:.4f}\tF1-binary-0: {f1_binary_0_gnn:.4f}"+ # f"\tF1-macro: {f1_macro_gnn:.4f}\tG-Mean: {gmean_gnn:.4f}\tAUC: {auc_gnn:.4f}") # print(f" GNN TP: {tp}\tTN: {tn}\tFN: {fn}\tFP: {fp}") # return f1_macro_gnn, f1_binary_1_gnn, f1_binary_0_gnn, auc_gnn, gmean_gnn sorted_gnn_prob_list = list(gnn_prob_list) sorted_labels_list = list(labels) combined = list(zip(sorted_gnn_prob_list, sorted_labels_list)) sorted_combined = sorted(combined, key=lambda x: x[0], reverse=True) sorted_gnn_prob_list, sorted_labels_list = zip(*sorted_combined) print("gnn_prob_list",sorted_gnn_prob_list[:10]) print("labels", sorted_labels_list[:10]) sorted_conf_gnn = confusion_matrix(sorted_labels_list[:10], prob2pred(np.array(sorted_gnn_prob_list[:10]), thres)) sorted_tn, sorted_fp, sorted_fn, sorted_tp = sorted_conf_gnn.ravel() # sorted_auc_gnn = roc_auc_score(sorted_labels_list[:10], prob2pred(np.array(sorted_gnn_prob_list[:10]), thres)) print("sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn",sorted_tn, sorted_fp, sorted_fn, sorted_tp, auc_gnn) return tn, fp, fn, tp, auc_gnn, sorted_tn, sorted_fp, sorted_fn, sorted_tp, auc_gnn # return tn, fp, fn, tp, auc_gnn def prob2pred(y_prob, thres=0.5): """ Convert probability to predicted results according to given threshold :param y_prob: numpy array of probability in [0, 1] :param thres: binary classification threshold, default 0.5 :returns: the predicted result with the same shape as y_prob """ y_pred = np.zeros_like(y_prob, dtype=np.int32) y_pred[y_prob >= thres] = 1 y_pred[y_prob < thres] = 0 return y_pred def test_pcgnn(test_cases, labels, model, batch_size, thres=0.5): """ Test the performance of PC-GNN and its variants :param test_cases: a list of testing node :param labels: a list of testing node labels :param model: the GNN model :param batch_size: number nodes in a batch :returns: the AUC and Recall of GNN and Simi modules """ test_batch_num = int(len(test_cases) / batch_size) + 1 f1_gnn = 0.0 acc_gnn = 0.0 recall_gnn = 0.0 f1_label1 = 0.0 acc_label1 = 0.00 recall_label1 = 0.0 gnn_pred_list = [] gnn_prob_list = [] label_list1 = [] for iteration in range(test_batch_num): i_start = iteration * batch_size i_end = min((iteration + 1) * batch_size, len(test_cases)) batch_nodes = test_cases[i_start:i_end] batch_label = labels[i_start:i_end] gnn_prob, label_prob1 = model.to_prob(batch_nodes, batch_label, train_flag=False) gnn_prob_arr = gnn_prob.data.cpu().numpy()[:, 1] gnn_pred = prob2pred(gnn_prob_arr, thres) f1_label1 += f1_score(batch_label, label_prob1.data.cpu().numpy().argmax(axis=1), average="macro") acc_label1 += accuracy_score(batch_label, label_prob1.data.cpu().numpy().argmax(axis=1)) recall_label1 += recall_score(batch_label, label_prob1.data.cpu().numpy().argmax(axis=1), average="macro") gnn_pred_list.extend(gnn_pred.tolist()) gnn_prob_list.extend(gnn_prob_arr.tolist()) label_list1.extend(label_prob1.data.cpu().numpy()[:, 1].tolist()) auc_gnn = roc_auc_score(labels, np.array(gnn_prob_list)) # ap_gnn = average_precision_score(labels, np.array(gnn_prob_list)) # auc_label1 = roc_auc_score(labels, np.array(label_list1)) # ap_label1 = average_precision_score(labels, np.array(label_list1)) # f1_binary_1_gnn = f1_score(labels, np.array(gnn_pred_list), pos_label=1, average='binary') # f1_binary_0_gnn = f1_score(labels, np.array(gnn_pred_list), pos_label=0, average='binary') # f1_micro_gnn = f1_score(labels, np.array(gnn_pred_list), average='micro') # f1_macro_gnn = f1_score(labels, np.array(gnn_pred_list), average='macro') conf_gnn = confusion_matrix(labels, np.array(gnn_pred_list)) tn, fp, fn, tp = conf_gnn.ravel() # gmean_gnn = conf_gmean(conf_gnn) # print(f" GNN F1-binary-1: {f1_binary_1_gnn:.4f}\tF1-binary-0: {f1_binary_0_gnn:.4f}"+ # f"\tF1-macro: {f1_macro_gnn:.4f}\tG-Mean: {gmean_gnn:.4f}\tAUC: {auc_gnn:.4f}") # print(f" GNN TP: {tp}\tTN: {tn}\tFN: {fn}\tFP: {fp}") # print(f"Label1 F1: {f1_label1 / test_batch_num:.4f}\tAccuracy: {acc_label1 / test_batch_num:.4f}"+ # f"\tRecall: {recall_label1 / test_batch_num:.4f}\tAUC: {auc_label1:.4f}\tAP: {ap_label1:.4f}") # return f1_macro_gnn, f1_binary_1_gnn, f1_binary_0_gnn, auc_gnn, gmean_gnn sorted_gnn_prob_list = list(gnn_prob_list) sorted_labels_list = list(labels) combined = list(zip(sorted_gnn_prob_list, sorted_labels_list)) sorted_combined = sorted(combined, key=lambda x: x[0], reverse=True) sorted_gnn_prob_list, sorted_labels_list = zip(*sorted_combined) print("gnn_prob_list",sorted_gnn_prob_list[:10]) print("labels", sorted_labels_list[:10]) sorted_conf_gnn = confusion_matrix(sorted_labels_list[:10], prob2pred(np.array(sorted_gnn_prob_list[:10]), thres)) sorted_tn, sorted_fp, sorted_fn, sorted_tp = sorted_conf_gnn.ravel() sorted_auc_gnn = roc_auc_score(sorted_labels_list[:10], prob2pred(np.array(sorted_gnn_prob_list[:10]), thres)) print("sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn",sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn) return tn, fp, fn, tp, auc_gnn, sorted_tn, sorted_fp, sorted_fn, sorted_tp, sorted_auc_gnn def conf_gmean(conf): tn, fp, fn, tp = conf.ravel() return (tp*tn/((tp+fn)*(tn+fp)))**0.5 ``` # model.py ```python! import torch import torch.nn as nn from torch.nn import init """ PC-GNN Model Paper: Pick and Choose: A GNN-based Imbalanced Learning Approach for Fraud Detection Modified from https://github.com/YingtongDou/CARE-GNN """ class PCALayer(nn.Module): """ One Pick-Choose-Aggregate layer """ def __init__(self, num_classes, inter1, lambda_1): """ Initialize the PC-GNN model :param num_classes: number of classes (2 in our paper) :param inter1: the inter-relation aggregator that output the final embedding """ super(PCALayer, self).__init__() self.inter1 = inter1 self.xent = nn.CrossEntropyLoss() # the parameter to transform the final embedding self.weight = nn.Parameter(torch.FloatTensor(num_classes, inter1.embed_dim)) init.xavier_uniform_(self.weight) self.lambda_1 = lambda_1 self.epsilon = 0.1 def forward(self, nodes, labels, train_flag=True): embeds1, label_scores = self.inter1(nodes, labels, train_flag) scores = self.weight.mm(embeds1) return scores.t(), label_scores def to_prob(self, nodes, labels, train_flag=True): gnn_logits, label_logits = self.forward(nodes, labels, train_flag) gnn_scores = torch.sigmoid(gnn_logits) label_scores = torch.sigmoid(label_logits) return gnn_scores, label_scores # def loss(self, nodes, labels, train_flag=True): # gnn_scores, label_scores = self.forward(nodes, labels, train_flag) # # Simi loss, Eq. (7) in the paper # label_loss = self.xent(label_scores, labels.squeeze()) # # GNN loss, Eq. (10) in the paper # gnn_loss = self.xent(gnn_scores, labels.squeeze()) # # the loss function of PC-GNN, Eq. (11) in the paper # final_loss = gnn_loss + self.lambda_1 * label_loss # return final_loss def focal_loss(self, input_probs, targets, alpha=0.5, gamma=2.0): pt = torch.exp(-self.xent(input_probs, targets)) # 預測概率的指數部分 focal_term = (1 - pt) ** gamma # Focal Loss 中的調整項 alpha_term = alpha * (1 - pt) # 平衡因子與 Focal Loss 的結合 focal_loss = alpha_term * focal_term * self.xent(input_probs, targets) return focal_loss.mean() # 平均 Focal Loss def loss(self, nodes, labels, train_flag=True): gnn_scores, label_scores = self.forward(nodes, labels, train_flag) # Simi loss, Eq. (7) in the paper label_loss = self.focal_loss(label_scores, labels.squeeze()) # GNN loss, Eq. (10) in the paper gnn_loss = self.focal_loss(gnn_scores, labels.squeeze()) # the loss function of PC-GNN, Eq. (11) in the paper final_loss = gnn_loss + self.lambda_1 * label_loss return final_loss ```