#tensorflow #jax
Вопрос:
Я могу запустить этот файл vit_jax.ipynb на colab, провести обучение и провести свои эксперименты, но когда я пытаюсь воспроизвести его в своем кластере, я получаю ошибку во время обучения, приведенного ниже. Однако прямой проход для вычисления точности отлично работает в моем кластере.
У меня есть 4 GTX 1080 с версией CUDA10.1 в моем кластере и с использованием tensorflow==2.4.0 и jax[cuda101]==0.2.18. Я запускаю это как записную книжку jupyter из контейнера docker.
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-57-176d6124ae02> in <module>()
11 opt_repl, loss_repl, update_rng_repl = update_fn_repl(
---> 12 opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)
13 losses.append(loss_repl[0])
/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
182 try:
--> 183 return fun(*args, **kwargs)
184 except Exception as e:
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in f_pmapped(*args, **kwargs)
1638 name=flat_fun.__name__, donated_invars=tuple(donated_invars),
-> 1639 global_arg_shapes=tuple(global_arg_shapes_flat))
1640 return tree_unflatten(out_tree(), out)
/usr/local/lib/python3.7/dist-packages/jax/core.py in bind(self, fun, *args, **params)
1620 assert len(params['in_axes']) == len(args)
-> 1621 return call_bind(self, fun, *args, **params)
1622
/usr/local/lib/python3.7/dist-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1551 tracers = map(top_trace.full_raise, args)
-> 1552 outs = primitive.process(top_trace, fun, tracers, params)
1553 return map(full_lower, apply_todos(env_trace_todo(), outs))
/usr/local/lib/python3.7/dist-packages/jax/core.py in process(self, trace, fun, tracers, params)
1623 def process(self, trace, fun, tracers, params):
-> 1624 return trace.process_map(self, fun, tracers, params)
1625
/usr/local/lib/python3.7/dist-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
606 def process_call(self, primitive, f, tracers, params):
--> 607 return primitive.impl(f, *tracers, **params)
608 process_map = process_call
/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in xla_pmap_impl(fun, backend, axis_name, axis_size, global_axis_size, devices, name, in_axes, out_axes_thunk, donated_invars, global_arg_shapes, *args)
636 ("fingerprint", fingerprint))
--> 637 return compiled_fun(*args)
638
/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in execute_replicated(compiled, backend, in_handler, out_handler, *args)
1159 input_bufs = in_handler(args)
-> 1160 out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
1161 if xla.needs_check_special():
UnfilteredStackTrace: RuntimeError: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:203: NCCL operation ncclGroupEnd() failed: unhandled system error: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
<ipython-input-57-176d6124ae02> in <module>()
10
11 opt_repl, loss_repl, update_rng_repl = update_fn_repl(
---> 12 opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)
13 losses.append(loss_repl[0])
14 lrs.append(lr_fn(step))
/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in execute_replicated(compiled, backend, in_handler, out_handler, *args)
1158 def execute_replicated(compiled, backend, in_handler, out_handler, *args):
1159 input_bufs = in_handler(args)
-> 1160 out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
1161 if xla.needs_check_special():
1162 for bufs in out_bufs:
RuntimeError: Internal: external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nccl_utils.cc:203: NCCL operation ncclGroupEnd() failed: unhandled system error: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
Пожалуйста, дайте мне знать, сталкивался ли кто-нибудь с этой проблемой раньше? Или есть какой-нибудь способ решить эту проблему?
Ответ №1:
Без дополнительной информации трудно сказать наверняка, но эта ошибка может быть вызвана нехваткой памяти GPU. В зависимости от ваших локальных настроек, вы можете исправить это, увеличив долю памяти GPU, зарезервированной XLA, например, установив XLA_PYTHON_CLIENT_MEM_FRACTION
системную переменную в 0.9
значение или что-то подобное.
В качестве альтернативы вы можете попробовать запустить свой код на более мелкой проблеме, которая помещается в память на вашем локальном оборудовании.
Комментарии:
1. Я могу запустить один и тот же код на одном графическом процессоре, уменьшив размер пакета (как вы упомянули выше), но он не работает с несколькими графическими процессорами даже с небольшими данными.