# 訓練檔案結構

## 標準配置

Train_project (資料夾)
* app(資料夾)
    * custom(資料夾)
        * 自定義的python訓練檔......
    * config(資料夾)
        * config_fed_client.json
        * config_fed_server.json
* meta.json

## 可以自行調整的參數

config_fed_client.json
* lr(learning rate): 0.001 ~ 0.1 (建議)
* epochs: 1 ~ 50
* data_path: 檔案路徑

config_fed_server.json
* min_clients
* num_rounds

## 設定檔

config_fed_client.json
```json=
{
    "format_version": 2,
    "executors": [
        {
            //執行的任務名稱
            "tasks": ["train", "submit_model"],
            "executor": {
                // 執行的物件路徑
                "path": "cifar10trainer.Cifar10Trainer",
                // 對應python檔的參數
                "args": {
                    "lr": 0.01,
                    "epochs": 1
                }
            }
        },
        {
            "tasks": ["validate"],
            "executor": {
                "path": "cifar10validator.Cifar10Validator",
                "args": { }
            }
        }
    ],
    "task_result_filters": [ ],
    "task_data_filters": [ ],
    "components": [ ]
}
```

config_fed_server.json
```json=
{
    "format_version": 2,
    "server": {
        "heart_beat_timeout": 600
    },
    "task_data_filters": [],
    "task_result_filters": [],
    "components": [
        {
            "id": "persistor",
            "path": "nvflare.app_common.pt.pt_file_model_persistor.PTFileModelPersistor",
            "args": {
                "model": {
                    "path": "torchvision.models.vgg19"
                }
            }
        },
        {
            "id": "shareable_generator",
            "path": "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator",
            "args": {}
        },
        {
            "id": "aggregator",
            "path": "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator",
            "args": {
                "expected_data_kind": "WEIGHTS"
            }
        },
        {
            "id": "model_locator",
            "path": "pt_model_locator.PTModelLocator",
            "args": {}
        },
        {
            "id": "json_generator",
            "path": "nvflare.app_common.widgets.validation_json_generator.ValidationJsonGenerator",
            "args": {}
        }
    ],
    "workflows": [
        {
            "id": "scatter_and_gather",
            "name": "ScatterAndGather",
            "args": {
                "min_clients" : 2,    // 最小的client數目
                "num_rounds" : 2,    // 聚合次數
                "start_round": 0,
                "wait_time_after_min_received": 10,
                "aggregator_id": "aggregator",
                "persistor_id": "persistor",
                "shareable_generator_id": "shareable_generator",
                "train_task_name": "train",
                "train_timeout": 0    // 最多的等待時間 -- 0 代表不設置
            }
        },
{ "id": "cross_site_validate", "name": "CrossSiteModelEval", "args": { "model_locator_id": "model_locator" } } ] } ``` meta.json ```json= { // 專案名稱 "name": "VGG19", "resource_spec": {}, // 最小的client數目 "min_clients" : 2, // 要部屬的程式 "deploy_map": { "app": [ "@ALL" ] } } ``` ## 範例專案: VGG-19 [GitHub](https://github.com/Kane-ouvic/nvflare_models/tree/main/VGG19) 基本上只需要修改以下三個文件,就可以換成其他模型進行訓練,註解有標明需要修改的地方,若要加入其他的功能需自行撰寫。 這邊做的是影像分類的任務使用的資料集(dataset)是cifar-10: config_fed_server.json ```json= { "format_version": 2, "server": { "heart_beat_timeout": 600 }, "task_data_filters": [], "task_result_filters": [], "components": [ { "id": "persistor", "path": "nvflare.app_common.pt.pt_file_model_persistor.PTFileModelPersistor", "args": { "model": { // 修改成模型的路徑 若要自訂layer可以使用simple_network.py "path": "torchvision.models.vgg19" } } }, { "id": "shareable_generator", "path": "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator", "args": {} }, { "id": "aggregator", "path": "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator", "args": { "expected_data_kind": "WEIGHTS" } }, { "id": "model_locator", "path": "pt_model_locator.PTModelLocator", "args": {} }, { "id": "json_generator", "path": "nvflare.app_common.widgets.validation_json_generator.ValidationJsonGenerator", "args": {} } ], "workflows": [ { "id": "scatter_and_gather", "name": "ScatterAndGather", "args": { "min_clients" : 2, "num_rounds" : 2, "start_round": 0, "wait_time_after_min_received": 10, "aggregator_id": "aggregator", "persistor_id": "persistor", "shareable_generator_id": "shareable_generator", "train_task_name": "train", "train_timeout": 0 } }, { "id": "cross_site_validate", "name": "CrossSiteModelEval", "args": { "model_locator_id": "model_locator" } } ] } ``` cifar10trainer.py ```python= import os.path import torch import torchvision from pt_constants import PTConstants from simple_network import SimpleNetwork from torch import nn 
from torch.optim import SGD from torch.utils.data.dataloader import DataLoader from torchvision.datasets import CIFAR10 from torchvision.transforms import Compose, Normalize, ToTensor from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable from nvflare.apis.executor import Executor from nvflare.apis.fl_constant import ReservedKey, ReturnCode from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import Shareable, make_reply from nvflare.apis.signal import Signal from nvflare.app_common.abstract.model import make_model_learnable, model_learnable_to_dxo from nvflare.app_common.app_constant import AppConstants from nvflare.app_common.pt.pt_fed_utils import PTModelPersistenceFormatManager class Cifar10Trainer(Executor): def __init__( self, data_path="./data", lr=0.01, epochs=5, train_task_name=AppConstants.TASK_TRAIN, submit_model_task_name=AppConstants.TASK_SUBMIT_MODEL, exclude_vars=None, ): """Cifar10 Trainer handles train and submit_model tasks. During train_task, it trains a simple network on CIFAR10 dataset. For submit_model task, it sends the locally trained model (if present) to the server. Args: lr (float, optional): Learning rate. Defaults to 0.01 epochs (int, optional): Epochs. Defaults to 5 train_task_name (str, optional): Task name for train task. Defaults to "train". submit_model_task_name (str, optional): Task name for submit model. Defaults to "submit_model". exclude_vars (list): List of variables to exclude during model loading. 
""" super().__init__() self._lr = lr self._epochs = epochs self._train_task_name = train_task_name self._submit_model_task_name = submit_model_task_name self._exclude_vars = exclude_vars # Training setup # 可修改成其他的預訓練模型 self.model = torchvision.models.vgg19() print(self.model) print(os.path.dirname(os.path.abspath(__file__))) print("======================================") self.device = torch.device( "cuda:0" if torch.cuda.is_available() else "cpu") self.model.to(self.device) self.loss = nn.CrossEntropyLoss() self.optimizer = SGD(self.model.parameters(), lr=lr, momentum=0.9) # Create Cifar10 dataset for training. transforms = Compose( [ ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ] ) self._train_dataset = CIFAR10( root=data_path, transform=transforms, download=True, train=True) self._train_loader = DataLoader( self._train_dataset, batch_size=4, shuffle=True) self._n_iterations = len(self._train_loader) # Setup the persistence manager to save PT model. # The default training configuration is used by persistence manager # in case no initial model is found. self._default_train_conf = { "train": {"model": type(self.model).__name__}} self.persistence_manager = PTModelPersistenceFormatManager( data=self.model.state_dict(), default_train_conf=self._default_train_conf ) def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: try: if task_name == self._train_task_name: # Get model weights try: dxo = from_shareable(shareable) except: self.log_error( fl_ctx, "Unable to extract dxo from shareable.") return make_reply(ReturnCode.BAD_TASK_DATA) # Ensure data kind is weights. if not dxo.data_kind == DataKind.WEIGHTS: self.log_error( fl_ctx, f"data_kind expected WEIGHTS but got {dxo.data_kind} instead.") return make_reply(ReturnCode.BAD_TASK_DATA) # Convert weights to tensor. 
Run training torch_weights = {k: torch.as_tensor( v) for k, v in dxo.data.items()} self._local_train(fl_ctx, torch_weights, abort_signal) # Check the abort_signal after training. # local_train returns early if abort_signal is triggered. if abort_signal.triggered: return make_reply(ReturnCode.TASK_ABORTED) # Save the local model after training. self._save_local_model(fl_ctx) # Get the new state dict and send as weights new_weights = self.model.state_dict() new_weights = {k: v.cpu().numpy() for k, v in new_weights.items()} outgoing_dxo = DXO( data_kind=DataKind.WEIGHTS, data=new_weights, meta={MetaKey.NUM_STEPS_CURRENT_ROUND: self._n_iterations}, ) return outgoing_dxo.to_shareable() elif task_name == self._submit_model_task_name: # Load local model ml = self._load_local_model(fl_ctx) # Get the model parameters and create dxo from it dxo = model_learnable_to_dxo(ml) return dxo.to_shareable() else: return make_reply(ReturnCode.TASK_UNKNOWN) except Exception as e: self.log_exception(fl_ctx, f"Exception in simple trainer: {e}.") return make_reply(ReturnCode.EXECUTION_EXCEPTION) def _local_train(self, fl_ctx, weights, abort_signal): # Set the model weights self.model.load_state_dict(state_dict=weights) # Basic training self.model.train() for epoch in range(self._epochs): running_loss = 0.0 for i, batch in enumerate(self._train_loader): if abort_signal.triggered: # If abort_signal is triggered, we simply return. # The outside function will check it again and decide steps to take. 
return images, labels = batch[0].to( self.device), batch[1].to(self.device) self.optimizer.zero_grad() predictions = self.model(images) cost = self.loss(predictions, labels) cost.backward() self.optimizer.step() running_loss += cost.cpu().detach().numpy() / images.size()[0] if i % 3000 == 0: self.log_info( fl_ctx, f"Epoch: {epoch}/{self._epochs}, Iteration: {i}, " f"Loss: {running_loss/3000}" ) running_loss = 0.0 def _save_local_model(self, fl_ctx: FLContext): run_dir = fl_ctx.get_engine().get_workspace().get_run_dir( fl_ctx.get_prop(ReservedKey.RUN_NUM)) models_dir = os.path.join(run_dir, PTConstants.PTModelsDir) if not os.path.exists(models_dir): os.makedirs(models_dir) model_path = os.path.join(models_dir, PTConstants.PTLocalModelName) ml = make_model_learnable(self.model.state_dict(), {}) self.persistence_manager.update(ml) torch.save(self.persistence_manager.to_persistence_dict(), model_path) def _load_local_model(self, fl_ctx: FLContext): run_dir = fl_ctx.get_engine().get_workspace().get_run_dir( fl_ctx.get_prop(ReservedKey.RUN_NUM)) models_dir = os.path.join(run_dir, PTConstants.PTModelsDir) if not os.path.exists(models_dir): return None model_path = os.path.join(models_dir, PTConstants.PTLocalModelName) self.persistence_manager = PTModelPersistenceFormatManager( data=torch.load(model_path), default_train_conf=self._default_train_conf ) ml = self.persistence_manager.to_model_learnable( exclude_vars=self._exclude_vars) return ml ``` cifar10validator.py ```python= import torch import torchvision from simple_network import SimpleNetwork from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10 from torchvision.transforms import Compose, Normalize, ToTensor from nvflare.apis.dxo import DXO, DataKind, from_shareable from nvflare.apis.executor import Executor from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import Shareable, make_reply from nvflare.apis.signal import 
Signal from nvflare.app_common.app_constant import AppConstants class Cifar10Validator(Executor): def __init__(self, data_path="~/data", validate_task_name=AppConstants.TASK_VALIDATION): super().__init__() self._validate_task_name = validate_task_name # Setup the model # 可修改成其他的預訓練模型 self.model = torchvision.models.vgg19() self.device = torch.device( "cuda") if torch.cuda.is_available() else torch.device("cpu") self.model.to(self.device) # Preparing the dataset for testing. transforms = Compose( [ ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ] ) test_data = CIFAR10(root=data_path, train=False, transform=transforms) self._test_loader = DataLoader(test_data, batch_size=4, shuffle=False) def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: if task_name == self._validate_task_name: model_owner = "?" try: try: dxo = from_shareable(shareable) except: self.log_error( fl_ctx, "Error in extracting dxo from shareable.") return make_reply(ReturnCode.BAD_TASK_DATA) # Ensure data_kind is weights. if not dxo.data_kind == DataKind.WEIGHTS: self.log_exception( fl_ctx, f"DXO is of type {dxo.data_kind} but expected type WEIGHTS.") return make_reply(ReturnCode.BAD_TASK_DATA) # Extract weights and ensure they are tensor. 
model_owner = shareable.get_header( AppConstants.MODEL_OWNER, "?") weights = {k: torch.as_tensor( v, device=self.device) for k, v in dxo.data.items()} # Get validation accuracy val_accuracy = self._validate(weights, abort_signal) if abort_signal.triggered: return make_reply(ReturnCode.TASK_ABORTED) self.log_info( fl_ctx, f"Accuracy when validating {model_owner}'s model on" f" {fl_ctx.get_identity_name()}" f"s data: {val_accuracy}", ) dxo = DXO(data_kind=DataKind.METRICS, data={"val_acc": val_accuracy}) return dxo.to_shareable() except: self.log_exception( fl_ctx, f"Exception in validating model from {model_owner}") return make_reply(ReturnCode.EXECUTION_EXCEPTION) else: return make_reply(ReturnCode.TASK_UNKNOWN) def _validate(self, weights, abort_signal): self.model.load_state_dict(weights) self.model.eval() correct = 0 total = 0 with torch.no_grad(): for i, (images, labels) in enumerate(self._test_loader): if abort_signal.triggered: return 0 images, labels = images.to(self.device), labels.to(self.device) output = self.model(images) _, pred_label = torch.max(output, 1) correct += (pred_label == labels).sum().item() total += images.size()[0] metric = correct / float(total) return metric ```