# SkyPilot Agent

## Overview

[SkyPilot](https://skypilot.readthedocs.io/en/latest/docs/index.html) is a framework for running LLMs, AI, and batch jobs on any cloud, offering maximum cost savings, highest GPU availability, and managed execution. Flyte can use SkyPilot to run a Flyte task on any cloud and allocate a GPU for it.

## Goal

- Run a Flyte task on any cloud using the cheapest available machine.
- Run a Flyte task on [spot instances](https://skypilot.readthedocs.io/en/latest/examples/spot-jobs.html) with automatic recovery from preemptions.

## Example

The SkyPilot agent submits any task that has a SkyPilot config to a cloud through SkyPilot. In the example below, tasks `t1` and `t2` reuse the same cluster because both specify `cluster_name="foo"`.

```python=
from flytekit import task, workflow
from flytekitplugins.skypilot import SkyPilot


@task(task_config=SkyPilot(cluster_name="foo"))
def t1() -> int:
    return 3 + 2


@task(task_config=SkyPilot(cluster_name="foo"))
def t2(a: int) -> int:
    return a + 3


@workflow
def wf(a: int = 3):
    t1()
    t2(a=a)
```

## SkyPilot Task

This implementation is similar to the [Databricks Task](https://github.com/flyteorg/flytekit/blob/ecc783566fa254fad31a9ccc3c443172955212bc/plugins/flytekit-spark/flytekitplugins/spark/task.py#L48).

```python=
from dataclasses import asdict, dataclass
from typing import Any, Callable, Dict, Optional, Union

from flytekit import ImageSpec, PythonFunctionTask
from flytekit.configuration import SerializationSettings
from flytekit.extend import TaskPlugins
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin


@dataclass
class SkyPilot(object):
    cluster_name: str
    ...


class SkyPilotFunctionTask(AsyncAgentExecutorMixin, PythonFunctionTask[SkyPilot]):
    _TASK_TYPE = "skypilot"

    def __init__(
        self,
        task_config: SkyPilot,
        task_function: Callable,
        container_image: Optional[Union[str, ImageSpec]] = None,
        **kwargs,
    ):
        super().__init__(
            task_config=task_config,
            task_type=self._TASK_TYPE,
            task_function=task_function,
            container_image=container_image,
            **kwargs,
        )

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        # Serialized into TaskTemplate.custom; the agent reads cluster_name from it.
        return asdict(self.task_config)


# Route @task(task_config=SkyPilot(...)) to SkyPilotFunctionTask.
TaskPlugins.register_pythontask_plugin(SkyPilot, SkyPilotFunctionTask)
```

## SkyPilot Agent

This agent submits jobs by leveraging SkyPilot.
The `TaskTemplate` contains all the information needed to create a SkyPilot task specification, including the image, command, resources, and input and output paths.

```python=
import shlex
from dataclasses import dataclass
from typing import Optional

import sky
from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate


@dataclass
class SkyPilotMetadata(ResourceMeta):
    """
    This is the metadata for the job.
    """

    job_id: int
    cluster_name: str
    cloud: Optional[str] = None
    region: Optional[str] = None


class SkyPilotAgent(AsyncAgentBase):
    def __init__(self):
        super().__init__(task_type_name="skypilot", metadata_type=SkyPilotMetadata)

    def create(
        self,
        task_template: TaskTemplate,
        inputs: Optional[LiteralMap] = None,
        **kwargs,
    ) -> SkyPilotMetadata:
        cluster_name = task_template.custom["cluster_name"]
        container = task_template.container
        # SkyPilot expects the run command as a single shell string.
        task = sky.Task(run=shlex.join(container.args))
        # The image (and hardware) belong to sky.Resources, not sky.Task.
        # Translating container.resources (CPU/GPU/memory requests) into
        # sky.Resources arguments is omitted here.
        task.set_resources(sky.Resources(image_id=f"docker:{container.image}"))
        # detach_run=True returns once the job is submitted instead of streaming
        # logs until completion. Recent SkyPilot releases return (job_id, handle).
        job_id, _ = sky.launch(task, cluster_name=cluster_name, detach_run=True)
        return SkyPilotMetadata(job_id=job_id, cluster_name=cluster_name)

    def get(self, resource_meta: SkyPilotMetadata, **kwargs) -> Resource:
        # get_skypilot_job_status is a helper not shown in this design.
        phase, outputs = get_skypilot_job_status(...)
        return Resource(phase=phase, outputs=outputs)

    def delete(self, resource_meta: SkyPilotMetadata, **kwargs):
        # Tear down the cluster.
        sky.down(resource_meta.cluster_name)


# Register the SkyPilot agent
AgentRegistry.register(SkyPilotAgent())
```
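Translating the `TaskTemplate` fields into a SkyPilot task is the core of `create`. The sketch below, which depends on neither `flytekit` nor `sky`, shows one possible way to build a dict in the shape of a SkyPilot task YAML (`resources.image_id`, `resources.cpus`, `resources.accelerators`, `run`) from a container image, its args, and CPU/GPU requests. The helper name `build_skypilot_spec` and the default accelerator type `T4` are hypothetical: Flyte's GPU request carries only a count, so the accelerator type would have to come from elsewhere (for example the task config). It also assumes the CPU request is already in SkyPilot's format, so Kubernetes-style values such as `500m` are not converted.

```python
import shlex
from typing import Any, Dict, List, Optional


def build_skypilot_spec(
    image: str,
    args: List[str],
    cpus: Optional[str] = None,
    gpus: int = 0,
    accelerator: str = "T4",  # hypothetical default; Flyte only records a GPU count
) -> Dict[str, Any]:
    """Build a SkyPilot-YAML-shaped dict from Flyte container fields (sketch)."""
    resources: Dict[str, Any] = {
        # Run the task inside the Flyte task's Docker image.
        "image_id": f"docker:{image}",
    }
    if cpus is not None:
        # SkyPilot accepts e.g. "4" (exactly 4 vCPUs) or "4+" (at least 4).
        resources["cpus"] = cpus
    if gpus > 0:
        # Accelerators are written as "<type>:<count>", e.g. "T4:1".
        resources["accelerators"] = f"{accelerator}:{gpus}"
    return {
        "resources": resources,
        # Quote each argument so the arg list survives as one shell command.
        "run": shlex.join(args),
    }


if __name__ == "__main__":
    spec = build_skypilot_spec(
        image="ghcr.io/example/flyte-task:v1",
        args=["pyflyte-execute", "--inputs", "s3://bucket/inputs.pb"],
        cpus="4+",
        gpus=1,
    )
    print(spec)
```

Keeping this translation as a pure function makes it easy to unit-test apart from any cloud. A dict like this could then be handed to `sky.Task.from_yaml_config`, assuming the installed SkyPilot version accepts these fields.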