SkyPilot Agent

Overview

SkyPilot is a framework for running LLMs, AI, and batch jobs on any cloud, offering maximum cost savings, highest GPU availability, and managed execution.

Flyte can use SkyPilot to run a Flyte task on any cloud, with SkyPilot provisioning the GPUs the task needs.

Goal

  • Run a Flyte task on any cloud, using the cheapest machine that satisfies its requirements.
  • Run Flyte tasks on spot instances with automatic recovery from preemptions.
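Conceptually, automatic recovery means relaunching a job whenever its spot instance is preempted, until the job finishes or a retry budget runs out. The toy sketch below illustrates only that control loop. It is not SkyPilot's implementation, and the names `Preempted`, `run_with_recovery`, and `max_restarts` are hypothetical.

```python
from typing import Callable


class Preempted(Exception):
    """Hypothetical signal that the spot instance running the job was reclaimed."""


def run_with_recovery(run_once: Callable[[int], str], max_restarts: int = 3) -> str:
    """Run a job, relaunching it after each preemption, up to max_restarts times.

    run_once receives the 0-based attempt number and either returns the job
    result or raises Preempted.
    """
    for attempt in range(max_restarts + 1):
        try:
            return run_once(attempt)
        except Preempted:
            # A real system would provision a fresh spot instance here
            # (possibly in another zone or region) before retrying.
            continue
    raise RuntimeError(f"job was preempted {max_restarts + 1} times; giving up")


# Usage: a fake job that is preempted on its first two attempts.
def flaky_job(attempt: int) -> str:
    if attempt < 2:
        raise Preempted()
    return f"done on attempt {attempt}"


print(run_with_recovery(flaky_job))  # prints "done on attempt 2"
```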

Example

Any task with a SkyPilot config is submitted to another cloud by the SkyPilot agent.

In the example below, tasks t1 and t2 reuse the same cluster because both set cluster_name="foo".

from flytekit import task, workflow
from flytekitplugins.skypilot import SkyPilot


# Both tasks use cluster_name="foo", so they share one SkyPilot cluster.
@task(task_config=SkyPilot(cluster_name="foo"))
def t1() -> int:
    return 3 + 2


@task(task_config=SkyPilot(cluster_name="foo"))
def t2(a: int) -> int:
    return a + 3


@workflow
def wf(a: int = 3):
    t1()
    t2(a=a)
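The reuse comes from keying clusters by name: launching a SkyPilot job with the name of a cluster that is already up runs the job on that cluster, while a new name provisions a new cluster. The toy model below makes no cloud calls, and its `ClusterPool` and `Cluster` classes are hypothetical. It only illustrates why t1 and t2, which both pass cluster_name="foo", end up on a single cluster.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Cluster:
    """Toy stand-in for a provisioned cluster; records the jobs it has run."""
    name: str
    jobs: List[str] = field(default_factory=list)


class ClusterPool:
    """Hypothetical model of launch-by-name: reuse the cluster if the name exists."""

    def __init__(self) -> None:
        self._clusters: Dict[str, Cluster] = {}
        self.provisioned = 0  # number of clusters actually created

    def launch(self, cluster_name: str, job: str) -> Cluster:
        cluster = self._clusters.get(cluster_name)
        if cluster is None:
            # First launch with this name: provision a new cluster.
            cluster = Cluster(cluster_name)
            self._clusters[cluster_name] = cluster
            self.provisioned += 1
        # Any later launch with the same name runs on the existing cluster.
        cluster.jobs.append(job)
        return cluster


pool = ClusterPool()
pool.launch("foo", "t1")
pool.launch("foo", "t2")
print(pool.provisioned)  # prints 1
```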

SkyPilot Task

This implementation is similar to the Databricks task.

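from dataclasses import asdict, dataclass
from typing import Any, Callable, Dict, Optional, Union

from flytekit import ImageSpec, PythonFunctionTask
from flytekit.configuration import SerializationSettings
from flytekit.extend import TaskPlugins
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin


@dataclass
class SkyPilot:
    cluster_name: str
    ...  # remaining fields elided in this design


class SkyPilotFunctionTask(AsyncAgentExecutorMixin, PythonFunctionTask[SkyPilot]):
    _TASK_TYPE = "skypilot"

    def __init__(
        self,
        task_config: SkyPilot,
        task_function: Callable,
        container_image: Optional[Union[str, ImageSpec]] = None,
        **kwargs,
    ):
        super().__init__(
            task_config=task_config,
            task_type=self._TASK_TYPE,
            task_function=task_function,
            container_image=container_image,
            **kwargs,
        )

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        # Serialized into TaskTemplate.custom, where the agent reads cluster_name.
        return asdict(self.task_config)


# Route @task(task_config=SkyPilot(...)) to SkyPilotFunctionTask.
TaskPlugins.register_pythontask_plugin(SkyPilot, SkyPilotFunctionTask)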
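get_custom is how the task configuration reaches the agent: asdict flattens the dataclass into a plain dict that flytekit stores in TaskTemplate.custom, which is why the agent can read task_template.custom["cluster_name"]. A minimal standalone sketch, using a stand-in dataclass with only the cluster_name field because the other fields are elided above, shows the conversion:

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict


@dataclass
class SkyPilot:
    # Stand-in with only the field shown in the design; the real config
    # may carry more fields.
    cluster_name: str


def get_custom(task_config: SkyPilot) -> Dict[str, Any]:
    # Same conversion that SkyPilotFunctionTask.get_custom performs.
    return asdict(task_config)


custom = get_custom(SkyPilot(cluster_name="foo"))
print(custom)                  # prints {'cluster_name': 'foo'}
print(custom["cluster_name"])  # the value the agent reads; prints foo
```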

SkyPilot Agent

This agent submits jobs through SkyPilot. The TaskTemplate contains all the information needed to build a SkyPilot task specification, including the image, command, resources, and input and output paths.

import shlex
from dataclasses import dataclass
from typing import Optional

import sky
from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate


@dataclass
class SkyPilotMetadata(ResourceMeta):
    """
    This is the metadata for the job.
    """

    job_id: int
    cluster_name: str
    # Defaults let create() build the metadata without knowing cloud/region yet.
    cloud: Optional[str] = None
    region: Optional[str] = None


class SkyPilotAgent(AsyncAgentBase):
    def __init__(self):
        super().__init__(task_type_name="skypilot", metadata_type=SkyPilotMetadata)

    def create(
        self,
        task_template: TaskTemplate,
        inputs: Optional[LiteralMap] = None,
        **kwargs,
    ) -> SkyPilotMetadata:
        cluster_name = task_template.custom["cluster_name"]
        image = task_template.container.image
        # SkyPilot's `run` is a shell command string; container args are a list.
        task = sky.Task(run=shlex.join(task_template.container.args))
        # Run inside the Flyte container image. Translating
        # task_template.container.resources (CPU, GPU, memory) into
        # sky.Resources fields is omitted here.
        task.set_resources(sky.Resources(image_id=f"docker:{image}"))
        # Assumes the synchronous SkyPilot Python API, where sky.launch returns
        # (job_id, handle); detach_run=True returns after submission so get()
        # can poll instead of blocking here.
        job_id, _handle = sky.launch(task, cluster_name=cluster_name, detach_run=True)
        return SkyPilotMetadata(job_id=job_id, cluster_name=cluster_name)

    def get(self, resource_meta: SkyPilotMetadata, **kwargs) -> Resource:
        # get_skypilot_job_status is a helper not shown in this design.
        phase, outputs = get_skypilot_job_status(...)
        return Resource(phase=phase, outputs=outputs)

    def delete(self, resource_meta: SkyPilotMetadata, **kwargs):
        # Tear down the whole cluster.
        sky.down(resource_meta.cluster_name)


# To register the skypilot agent
AgentRegistry.register(SkyPilotAgent())
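One detail of that translation is that a Flyte container's args field is a list of strings, while a SkyPilot task's run field is a single shell command. The sketch below is a minimal illustration. It uses a plain dict in place of sky.Task, loosely mirroring the run and resources.image_id keys of a SkyPilot task YAML, and assumes SkyPilot's docker:<image> convention for image_id. The function name build_sky_spec, the image name, and the argument values are all hypothetical. shlex.join keeps arguments that contain spaces quoted.

```python
import shlex
from typing import Any, Dict, List


def build_sky_spec(image: str, args: List[str]) -> Dict[str, Any]:
    """Translate TaskTemplate-like fields into a SkyPilot-YAML-shaped dict."""
    return {
        # shlex.join quotes each argument so spaces survive the shell.
        "run": shlex.join(args),
        # Assumed convention: container images are passed as "docker:<image>".
        "resources": {"image_id": f"docker:{image}"},
    }


# Illustrative values only.
spec = build_sky_spec(
    image="my-registry/my-image:latest",
    args=["pyflyte-execute", "--inputs", "s3://my-bucket/inputs.pb"],
)
print(spec["run"])  # prints pyflyte-execute --inputs s3://my-bucket/inputs.pb
```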