# Jax XLA distributed runtime GPU test ### Environment docker with jaxlib installed from https://storage.googleapis.com/jax-temp-releases/jaxlib-0.1.71-cp39-none-manylinux2010_x86_64.whl + tf-nightly on a node with 2 A100 GPUs. ### Script ``` # zhangqiaorjc@google.com import functools from absl import app from absl import flags from absl import logging import jax from jax.lib import xla_extension as xc flags.DEFINE_string('server_ip', '', help='server ip addr') flags.DEFINE_integer('server_port', 36657, help='server ip port') flags.DEFINE_integer('num_hosts', 1, help='num of hosts' ) flags.DEFINE_integer('host_idx', 0, help='index of current host' ) FLAGS = flags.FLAGS def connect_to_gpu_cluster(): service = None if FLAGS.host_idx == 0: addr = f'localhost:{FLAGS.server_port}' server_addr = f'{FLAGS.server_ip}:{FLAGS.server_port}' logging.info('starting service on %s', addr) service = xc.get_distributed_runtime_service(addr, FLAGS.num_hosts) logging.info('connecting to service on %s', server_addr) dist_client = xc.get_distributed_runtime_client(server_addr, FLAGS.host_idx) # on a different process change 0 to 1 or 2... 
# register dist gpu backend factory = functools.partial(jax.lib.xla_client.make_gpu_client, dist_client, FLAGS.host_idx) jax.lib.xla_bridge.register_backend_factory('gpu', factory, priority=300) return service def main(argv): service = connect_to_gpu_cluster() logging.info('gpu cluster connected') logging.info('devices %s', jax.devices()) logging.info('local devices %s', jax.local_devices()) logging.info('shutting down gpu cluster...') del service if __name__ == '__main__': app.run(main) ``` ### Error log ``` /workspace/sandbox/jax_gpu# python3.9 jax_gpu.py --host_idx=0 I0824 13:02:13.397778 140149126596416 jax_gpu.py:24] starting service on localhost:36657 I0824 13:02:13.399801 140149126596416 jax_gpu.py:27] connecting to service on :36657 I0824 13:02:13.399940 140149126596416 jax_gpu.py:38] gpu cluster connected I0824 13:02:13.480894 140149126596416 xla_bridge.py:236] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: I0824 13:02:13.963398 140149126596416 xla_bridge.py:236] Unable to initialize backend 'gpu': FAILED_PRECONDITION: EnumerateDevices() called when client not connected. I0824 13:02:13.964243 140149126596416 xla_bridge.py:236] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available. W0824 13:02:13.964388 140149126596416 xla_bridge.py:240] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.) I0824 13:02:13.964455 140149126596416 jax_gpu.py:40] devices [<jaxlib.xla_extension.Device object at 0x7f74d530e530>] I0824 13:02:13.964586 140149126596416 jax_gpu.py:41] local devices [<jaxlib.xla_extension.Device object at 0x7f74d530e530>] I0824 13:02:13.964632 140149126596416 jax_gpu.py:43] shutting down gpu cluster... 
``` If before the line `service = connect_to_gpu_cluster()` I add anything that initializes the JAX XLA backend, such as `jax.devices()` or `jax.lib.xla_bridge.get_backend().platform`, it works: ``` /workspace/sandbox/jax_gpu# python3.9 jax_gpu.py --host_idx=0 I0824 13:01:26.497866 139928270825280 xla_bridge.py:236] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: I0824 13:01:27.124426 139928270825280 xla_bridge.py:236] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available. gpu I0824 13:01:27.124691 139928270825280 jax_gpu.py:24] starting service on localhost:36657 I0824 13:01:27.126261 139928270825280 jax_gpu.py:27] connecting to service on :36657 I0824 13:01:27.126422 139928270825280 jax_gpu.py:39] gpu cluster connected I0824 13:01:27.126481 139928270825280 jax_gpu.py:41] devices [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)] I0824 13:01:27.126769 139928270825280 jax_gpu.py:42] local devices [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)] I0824 13:01:27.126814 139928270825280 jax_gpu.py:44] shutting down gpu cluster... ``` ### Verbose log ``` /workspace/sandbox/jax_gpu# TF_CPP_MIN_LOG_LEVEL=0 python3.9 jax_gpu.py --v=1 --host_idx=0 I0824 13:03:28.986649 140599124092736 jax_gpu.py:24] starting service on localhost:36657 2021-08-24 13:03:28.988447: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/distributed/service.cc:369] Jax service listening on localhost:36657 I0824 13:03:28.988496 140599124092736 jax_gpu.py:27] connecting to service on :36657 I0824 13:03:28.988650 140599124092736 jax_gpu.py:38] gpu cluster connected I0824 13:03:28.988715 140599124092736 xla_bridge.py:214] Initializing backend 'interpreter' 2021-08-24 13:03:28.989316: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:171] XLA service 0x27213c0 initialized for platform Interpreter (this does not guarantee that XLA will be used). 
Devices: 2021-08-24 13:03:28.989337: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:179] StreamExecutor device (0): Interpreter, <undefined> I0824 13:03:29.015982 140599124092736 xla_bridge.py:224] Backend 'interpreter' initialized I0824 13:03:29.016132 140599124092736 xla_bridge.py:214] Initializing backend 'cpu' 2021-08-24 13:03:29.069008: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc:163] TfrtCpuClient created. I0824 13:03:29.069246 140599124092736 xla_bridge.py:224] Backend 'cpu' initialized I0824 13:03:29.069410 140599124092736 xla_bridge.py:214] Initializing backend 'tpu_driver' I0824 13:03:29.069650 140599124092736 xla_bridge.py:236] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: I0824 13:03:29.069706 140599124092736 xla_bridge.py:214] Initializing backend 'gpu' 2021-08-24 13:03:29.818912: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:171] XLA service 0x2f51de0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices: 2021-08-24 13:03:29.818998: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:179] StreamExecutor device (0): A100-SXM4-40GB, Compute Capability 8.0 2021-08-24 13:03:29.819008: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:179] StreamExecutor device (1): A100-SXM4-40GB, Compute Capability 8.0 2021-08-24 13:03:29.825085: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/gpu_device.cc:309] Using BFC allocator. 2021-08-24 13:03:29.825168: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/gpu_device.cc:268] XLA backend allocating 2468944281 bytes on device 0 for BFCAllocator. 2021-08-24 13:03:29.826200: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/gpu_device.cc:268] XLA backend allocating 2468944281 bytes on device 1 for BFCAllocator. 
I0824 13:03:29.837855 140599124092736 xla_bridge.py:236] Unable to initialize backend 'gpu': FAILED_PRECONDITION: EnumerateDevices() called when client not connected. I0824 13:03:29.838031 140599124092736 xla_bridge.py:214] Initializing backend 'tpu' 2021-08-24 13:03:29.838459: I external/org_tensorflow/tensorflow/stream_executor/tpu/tpu_platform_interface.cc:74] No TPU platform found. I0824 13:03:29.838530 140599124092736 xla_bridge.py:236] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available. W0824 13:03:29.838790 140599124092736 xla_bridge.py:240] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.) I0824 13:03:29.838853 140599124092736 jax_gpu.py:40] devices [<jaxlib.xla_extension.Device object at 0x7fdceb5c94b0>] I0824 13:03:29.838979 140599124092736 jax_gpu.py:41] local devices [<jaxlib.xla_extension.Device object at 0x7fdceb5c94b0>] I0824 13:03:29.839026 140599124092736 jax_gpu.py:43] shutting down gpu cluster... 2021-08-24 13:03:29.839062: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/distributed/service.cc:379] Jax service shutting down ```