# Jax XLA distributed runtime GPU test
### Environment
Docker container with jaxlib installed from https://storage.googleapis.com/jax-temp-releases/jaxlib-0.1.71-cp39-none-manylinux2010_x86_64.whl plus tf-nightly, on a node with 2 A100 GPUs.
### Script
```python
# zhangqiaorjc@google.com
import functools

from absl import app
from absl import flags
from absl import logging

import jax
from jax.lib import xla_extension as xc

flags.DEFINE_string('server_ip', '', help='server ip addr')
flags.DEFINE_integer('server_port', 36657, help='server ip port')
flags.DEFINE_integer('num_hosts', 1, help='num of hosts')
flags.DEFINE_integer('host_idx', 0, help='index of current host')

FLAGS = flags.FLAGS


def connect_to_gpu_cluster():
  service = None
  if FLAGS.host_idx == 0:
    addr = f'localhost:{FLAGS.server_port}'
    logging.info('starting service on %s', addr)
    service = xc.get_distributed_runtime_service(addr, FLAGS.num_hosts)
  server_addr = f'{FLAGS.server_ip}:{FLAGS.server_port}'
  logging.info('connecting to service on %s', server_addr)
  # On a different process, pass that process's host index instead of 0.
  dist_client = xc.get_distributed_runtime_client(server_addr, FLAGS.host_idx)
  # Register the distributed GPU backend with high priority.
  factory = functools.partial(jax.lib.xla_client.make_gpu_client, dist_client,
                              FLAGS.host_idx)
  jax.lib.xla_bridge.register_backend_factory('gpu', factory, priority=300)
  return service


def main(argv):
  service = connect_to_gpu_cluster()
  logging.info('gpu cluster connected')
  logging.info('devices %s', jax.devices())
  logging.info('local devices %s', jax.local_devices())
  logging.info('shutting down gpu cluster...')
  del service


if __name__ == '__main__':
  app.run(main)
```
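The script follows the usual coordinator pattern: host 0 runs the coordination service, and every host (host 0 included) connects a client to it before the GPU backend is created. A toy in-process sketch of that handshake, using a hypothetical `ToyService` as a stand-in for `xc.get_distributed_runtime_service` / `xc.get_distributed_runtime_client` (the real connect semantics may differ; this just models a barrier over `num_hosts` connections):

```python
import threading

class ToyService:
    """Hypothetical stand-in for the distributed runtime service."""

    def __init__(self, num_hosts):
        self.num_hosts = num_hosts
        self._connected = 0
        self._cv = threading.Condition()

    def connect(self, host_idx):
        # Each host registers itself, then blocks until all hosts have
        # connected, barrier-style.
        with self._cv:
            self._connected += 1
            self._cv.notify_all()
            self._cv.wait_for(lambda: self._connected == self.num_hosts)
        return f'client-{host_idx}'

# Simulate two hosts connecting to the service run by "host 0".
service = ToyService(num_hosts=2)
results = []

def host(idx):
    results.append(service.connect(idx))

threads = [threading.Thread(target=host, args=(i,)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(sorted(results))  # ['client-0', 'client-1']
```

In the real script, the same flag values are shared by every process; only `--host_idx` differs per host.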
### Error log
```
/workspace/sandbox/jax_gpu# python3.9 jax_gpu.py --host_idx=0
I0824 13:02:13.397778 140149126596416 jax_gpu.py:24] starting service on localhost:36657
I0824 13:02:13.399801 140149126596416 jax_gpu.py:27] connecting to service on :36657
I0824 13:02:13.399940 140149126596416 jax_gpu.py:38] gpu cluster connected
I0824 13:02:13.480894 140149126596416 xla_bridge.py:236] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
I0824 13:02:13.963398 140149126596416 xla_bridge.py:236] Unable to initialize backend 'gpu': FAILED_PRECONDITION: EnumerateDevices() called when client not connected.
I0824 13:02:13.964243 140149126596416 xla_bridge.py:236] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
W0824 13:02:13.964388 140149126596416 xla_bridge.py:240] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
I0824 13:02:13.964455 140149126596416 jax_gpu.py:40] devices [<jaxlib.xla_extension.Device object at 0x7f74d530e530>]
I0824 13:02:13.964586 140149126596416 jax_gpu.py:41] local devices [<jaxlib.xla_extension.Device object at 0x7f74d530e530>]
I0824 13:02:13.964632 140149126596416 jax_gpu.py:43] shutting down gpu cluster...
```
If, before the line `service = connect_to_gpu_cluster()`, I add anything that initializes the JAX XLA backend, such as `jax.devices()` or `jax.lib.xla_bridge.get_backend().platform`, it works:
```
/workspace/sandbox/jax_gpu# python3.9 jax_gpu.py --host_idx=0
I0824 13:01:26.497866 139928270825280 xla_bridge.py:236] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
I0824 13:01:27.124426 139928270825280 xla_bridge.py:236] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
gpu
I0824 13:01:27.124691 139928270825280 jax_gpu.py:24] starting service on localhost:36657
I0824 13:01:27.126261 139928270825280 jax_gpu.py:27] connecting to service on :36657
I0824 13:01:27.126422 139928270825280 jax_gpu.py:39] gpu cluster connected
I0824 13:01:27.126481 139928270825280 jax_gpu.py:41] devices [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)]
I0824 13:01:27.126769 139928270825280 jax_gpu.py:42] local devices [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)]
I0824 13:01:27.126814 139928270825280 jax_gpu.py:44] shutting down gpu cluster...
```
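For context on the `priority=300` argument: `xla_bridge` tries registered backend factories in priority order and falls back to lower-priority ones when a factory fails to initialize, which is why the failed `'gpu'` backend above silently degrades to CPU. A toy model of that selection rule, with hypothetical names rather than JAX's actual internals:

```python
# Toy priority-based backend selection, mimicking the idea behind
# xla_bridge.register_backend_factory(name, factory, priority=...).
_factories = {}

def register_backend_factory(name, factory, priority=0):
    _factories[name] = (priority, factory)

def get_default_backend():
    # Try factories from highest to lowest priority, skipping any that
    # fail to initialize (as 'gpu' does in the error log above).
    for name, (_, factory) in sorted(
            _factories.items(), key=lambda kv: -kv[1][0]):
        try:
            return name, factory()
        except RuntimeError:
            continue
    raise RuntimeError('no backend could be initialized')

register_backend_factory('cpu', lambda: 'TfrtCpuClient', priority=0)

def broken_gpu_client():
    # Mirrors the failure mode from the log above.
    raise RuntimeError('EnumerateDevices() called when client not connected.')

register_backend_factory('gpu', broken_gpu_client, priority=300)

name, client = get_default_backend()
print(name, client)  # cpu TfrtCpuClient
```

This only models the fallback behavior visible in the logs; it does not explain why pre-initializing the backend makes the distributed client usable.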
### Verbose log
```
/workspace/sandbox/jax_gpu# TF_CPP_MIN_LOG_LEVEL=0 python3.9 jax_gpu.py --v=1 --host_idx=0
I0824 13:03:28.986649 140599124092736 jax_gpu.py:24] starting service on localhost:36657
2021-08-24 13:03:28.988447: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/distributed/service.cc:369] Jax service listening on localhost:36657
I0824 13:03:28.988496 140599124092736 jax_gpu.py:27] connecting to service on :36657
I0824 13:03:28.988650 140599124092736 jax_gpu.py:38] gpu cluster connected
I0824 13:03:28.988715 140599124092736 xla_bridge.py:214] Initializing backend 'interpreter'
2021-08-24 13:03:28.989316: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:171] XLA service 0x27213c0 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
2021-08-24 13:03:28.989337: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:179] StreamExecutor device (0): Interpreter, <undefined>
I0824 13:03:29.015982 140599124092736 xla_bridge.py:224] Backend 'interpreter' initialized
I0824 13:03:29.016132 140599124092736 xla_bridge.py:214] Initializing backend 'cpu'
2021-08-24 13:03:29.069008: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc:163] TfrtCpuClient created.
I0824 13:03:29.069246 140599124092736 xla_bridge.py:224] Backend 'cpu' initialized
I0824 13:03:29.069410 140599124092736 xla_bridge.py:214] Initializing backend 'tpu_driver'
I0824 13:03:29.069650 140599124092736 xla_bridge.py:236] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
I0824 13:03:29.069706 140599124092736 xla_bridge.py:214] Initializing backend 'gpu'
2021-08-24 13:03:29.818912: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:171] XLA service 0x2f51de0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2021-08-24 13:03:29.818998: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:179] StreamExecutor device (0): A100-SXM4-40GB, Compute Capability 8.0
2021-08-24 13:03:29.819008: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:179] StreamExecutor device (1): A100-SXM4-40GB, Compute Capability 8.0
2021-08-24 13:03:29.825085: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/gpu_device.cc:309] Using BFC allocator.
2021-08-24 13:03:29.825168: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/gpu_device.cc:268] XLA backend allocating 2468944281 bytes on device 0 for BFCAllocator.
2021-08-24 13:03:29.826200: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/gpu_device.cc:268] XLA backend allocating 2468944281 bytes on device 1 for BFCAllocator.
I0824 13:03:29.837855 140599124092736 xla_bridge.py:236] Unable to initialize backend 'gpu': FAILED_PRECONDITION: EnumerateDevices() called when client not connected.
I0824 13:03:29.838031 140599124092736 xla_bridge.py:214] Initializing backend 'tpu'
2021-08-24 13:03:29.838459: I external/org_tensorflow/tensorflow/stream_executor/tpu/tpu_platform_interface.cc:74] No TPU platform found.
I0824 13:03:29.838530 140599124092736 xla_bridge.py:236] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
W0824 13:03:29.838790 140599124092736 xla_bridge.py:240] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
I0824 13:03:29.838853 140599124092736 jax_gpu.py:40] devices [<jaxlib.xla_extension.Device object at 0x7fdceb5c94b0>]
I0824 13:03:29.838979 140599124092736 jax_gpu.py:41] local devices [<jaxlib.xla_extension.Device object at 0x7fdceb5c94b0>]
I0824 13:03:29.839026 140599124092736 jax_gpu.py:43] shutting down gpu cluster...
2021-08-24 13:03:29.839062: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/distributed/service.cc:379] Jax service shutting down
```