# SkyPilot Agent
## Overview
[SkyPilot](https://skypilot.readthedocs.io/en/latest/docs/index.html) is a framework for running LLMs, AI, and batch jobs on any cloud, offering maximum cost savings, highest GPU availability, and managed execution.
Flyte can use SkyPilot to run a Flyte task on any cloud and provision a GPU for it.
## Goal
- Run a Flyte task on any cloud, on the cheapest available machine.
- Run a Flyte task on [spot instances](https://skypilot.readthedocs.io/en/latest/examples/spot-jobs.html) with automatic recovery from preemptions.
## Example
Any task that has a SkyPilot config is submitted to the target cloud by the SkyPilot agent.
In the example below, tasks `t1` and `t2` reuse the same cluster because they share the same `cluster_name`.
```python=
from flytekit import task, workflow
from flytekitplugins.skypilot import SkyPilot


@task(task_config=SkyPilot(cluster_name="foo"))
def t1() -> int:
    return 3 + 2


@task(task_config=SkyPilot(cluster_name="foo"))
def t2(a: int) -> int:
    return a + 3


@workflow
def wf(a: int = 3):
    t1()
    t2(a=a)
```
## SkyPilot Task
This implementation is similar to the [Databricks Task](https://github.com/flyteorg/flytekit/blob/ecc783566fa254fad31a9ccc3c443172955212bc/plugins/flytekit-spark/flytekitplugins/spark/task.py#L48).
```python=
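from dataclasses import asdict, dataclass
from typing import Any, Callable, Dict, Optional, Union

from flytekit import PythonFunctionTask
from flytekit.configuration import SerializationSettings
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
from flytekit.image_spec import ImageSpec


@dataclass
class SkyPilot(object):
    cluster_name: str
    ...


class SkyPilotFunctionTask(AsyncAgentExecutorMixin, PythonFunctionTask[SkyPilot]):
    _TASK_TYPE = "skypilot"

    def __init__(
        self,
        task_config: SkyPilot,
        task_function: Callable,
        container_image: Optional[Union[str, ImageSpec]] = None,
        **kwargs,
    ):
        super().__init__(
            task_config=task_config,
            task_type=self._TASK_TYPE,
            task_function=task_function,
            container_image=container_image,
            **kwargs,
        )

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        # The serialized config becomes `task_template.custom` on the agent side.
        return asdict(self.task_config)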
```
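To make the hand-off to the agent concrete, here is a minimal, standard-library-only sketch of what `get_custom` produces: `asdict` flattens the config dataclass into a plain dictionary, and the agent later reads that dictionary as `task_template.custom`. The `SkyPilotConfig` stand-in and the `to_custom` helper are hypothetical names used only for illustration; the real config may carry more fields than `cluster_name`.

```python
from dataclasses import asdict, dataclass


@dataclass
class SkyPilotConfig:
    """Illustrative stand-in for the SkyPilot task config (one field only)."""

    cluster_name: str


def to_custom(config: SkyPilotConfig) -> dict:
    # Same idea as get_custom: dataclass -> plain, serializable dict.
    return asdict(config)


custom = to_custom(SkyPilotConfig(cluster_name="foo"))
print(custom)  # {'cluster_name': 'foo'}
```

Because the dictionary keys mirror the dataclass field names, any field added to the config is available to the agent under the same key without extra serialization code.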
## SkyPilot Agent
This agent submits jobs through SkyPilot. The `TaskTemplate` contains all the job information needed to create a SkyPilot task specification: the image, the command, the resources, and the input and output paths.
```python=
import shlex
from dataclasses import dataclass
from typing import Optional

import sky
from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate


@dataclass
class SkyPilotMetadata(ResourceMeta):
    """
    This is the metadata for the job.
    """

    job_id: str
    cluster_name: str
    # Defaults let create() omit fields that are not known at submission time.
    cloud: Optional[str] = None
    region: Optional[str] = None


class SkyPilotAgent(AsyncAgentBase):
    def __init__(self):
        super().__init__(task_type_name="skypilot", metadata_type=SkyPilotMetadata)

    def create(
        self,
        task_template: TaskTemplate,
        inputs: Optional[LiteralMap] = None,
        **kwargs,
    ) -> SkyPilotMetadata:
        cluster_name = task_template.custom["cluster_name"]
        container = task_template.container
        # SkyPilot takes the run command as one shell string, so join the container args.
        task = sky.Task(run=shlex.join(container.args))
        # Run inside the Flyte task image. Translating container.resources
        # (CPU, memory, GPU) into sky.Resources is omitted from this sketch.
        task.set_resources(sky.Resources(image_id=f"docker:{container.image}"))
        # Assumes a SkyPilot release where sky.launch returns (job_id, handle).
        job_id, _ = sky.launch(task, cluster_name=cluster_name)
        return SkyPilotMetadata(job_id=str(job_id), cluster_name=cluster_name)

    def get(self, resource_meta: SkyPilotMetadata, **kwargs) -> Resource:
        phase, outputs = get_skypilot_job_status(...)
        return Resource(phase=phase, outputs=outputs)

    def delete(self, resource_meta: SkyPilotMetadata, **kwargs):
        # Tear down the cluster that ran the job.
        sky.down(resource_meta.cluster_name)


# To register the skypilot agent
AgentRegistry.register(SkyPilotAgent())
```
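The `get` method has to translate whatever SkyPilot reports about a job into a Flyte phase. One possible shape for that translation is sketched below. The helper names `to_flyte_phase` and `is_terminal` are hypothetical, and the snippet uses plain strings so that it runs without SkyPilot or flytekit installed. The status names are the ones SkyPilot's job status enum is assumed to use, and the phase names are assumed to match Flyte's task execution phases; check both against the installed versions before relying on them.

```python
# Hypothetical mapping from SkyPilot job status names to Flyte phase names.
# Both sets of names are assumptions to verify against the installed libraries.
_STATUS_TO_PHASE = {
    "INIT": "INITIALIZING",
    "PENDING": "QUEUED",
    "SETTING_UP": "INITIALIZING",
    "RUNNING": "RUNNING",
    "SUCCEEDED": "SUCCEEDED",
    "FAILED": "FAILED",
    "FAILED_SETUP": "FAILED",
    "CANCELLED": "ABORTED",
}

_TERMINAL_PHASES = {"SUCCEEDED", "FAILED", "ABORTED"}


def to_flyte_phase(sky_status: str) -> str:
    """Translate a SkyPilot job status name; unknown names map to UNDEFINED."""
    return _STATUS_TO_PHASE.get(sky_status.upper(), "UNDEFINED")


def is_terminal(phase: str) -> bool:
    """Return True once the job can no longer change state."""
    return phase in _TERMINAL_PHASES


for status in ("PENDING", "RUNNING", "SUCCEEDED"):
    phase = to_flyte_phase(status)
    print(status, "->", phase, "(terminal)" if is_terminal(phase) else "")
```

Mapping unknown statuses to `UNDEFINED` rather than raising keeps polling alive if SkyPilot adds a status the agent does not yet recognize; a stricter agent could treat that case as a failure instead.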