Клип OpenAI (PyTorch) зависает при использовании многопроцессорной обработки

#python #pytorch #multiprocessing

Вопрос:

При запуске КЛИПА внутри многопроцессорной обработки.Процесс, система зависает, как только она достигает этапа предварительной обработки (на практике я предполагаю, что на самом деле это любая операция факела). Минимальный пример:

 import torch
import clip
from PIL import Image
import multiprocessing as mp
import sys

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)


def infer():
  print("PREPROCESSING")
  sys.stdout.flush()
  image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)

  print("TOKENIZING")
  sys.stdout.flush()
  text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

  print("INFERRING")
  sys.stdout.flush()
  with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    print(f'{image_features.shape}')
    print(f'{text_features.shape}')
    sys.stdout.flush()

    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

  print(
      f"Label probs: {probs}")  # prints: [[0.9927937  0.00421068 0.00299572]]
  sys.stdout.flush()


p = mp.Process(target=infer, daemon=True)
p.start()
p.join()
 

Этот пример является эквивалентом примера начального руководства по использованию в файле ЧТЕНИЯ клипа, но заключает вывод модели в процесс.

Вывод этого кода является:

 /home/amol/code/soot/debugging/clip_tests/env/lib/python3.8/site-packages/torch/cuda/__init__.py:52: UserWarning: CUDA initialization: Found no NVIDIA driver on your system. Please check th
at you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx (Triggered internally at  /pytorch/c10/cuda/CUDAFunctions.cpp:100.)
  return torch._C._cuda_getDeviceCount() > 0                                                  
PREPROCESSING   
 

после чего он зависает. Есть какие-нибудь предложения?

РЕДАКТИРОВАТЬ Мне удалось сделать это еще меньше, покопавшись в КЛИПЕ, чтобы увидеть, в чем проблемы. Я добрался сюда:

 import os                                                                                                                                                                 
import urllib                                                                                                                                                             
                                                                                                                                                                          
from tqdm import tqdm                                                                                                                                                     
                                                                                                                                                                          
import torch                                                                                                                                                              
import clip                                                                                                                                                               
from PIL import Image                                                                                                                                                     
import multiprocessing as mp                                                                                                                                              
                                                                                                                                                                          
                                                                                                                                                                          
def _download(url, root=os.path.expanduser("~/.cache/clip")):                        
  os.makedirs(root, exist_ok=True)        
  download_target = os.path.join(root, os.path.basename(url))                        
                                                                                                                                                                          
  with urllib.request.urlopen(url) as source, open(download_target,                  
                                                   "wb") as output:                  
    with tqdm(total=int(source.info().get("Content-Length")),                        
              ncols=80,
              unit='iB',
              unit_scale=True) as loop:   
      while True:                         
        buffer = source.read(8192)        
        if not buffer:                                                                                                                                                    
          break                                                                      

        output.write(buffer)
        loop.update(len(buffer))

  return download_target


def load(device="cpu"):
  model_path = _download(
      "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"  # noqa
  )                                       
  model = torch.jit.load(model_path, map_location="cpu").eval()                      
  model = clip.model.build_model(model.state_dict()).to(device)
  
load()                                    


def test():                               
  print("GETTING IMAGE")
  im = Image.open("CLIP.png")
  print("CONVERTING")
  im = im.convert('RGB')
  print("MADE TENSOR")
  img = torch.ByteTensor(torch.ByteStorage.from_buffer(im.tobytes()))                
  print("VIEW")                           
  img = img.view(im.size[1], im.size[0], len(im.getbands()))                         
  print("PERMUTING")                      
  img = img.permute((2, 0, 1)).contiguous()                                          
  print("DIV")                            
  img = img.float().div(255)
  print("UNSQUEEZE")                      
  img = img.unsqueeze(0)


p = mp.Process(target=test, daemon=True)
p.start()                                 
p.join()
 

Обратите внимание, что ничто, созданное вызовом load() или download (), на самом деле не используется. Кроме того, если я закомментирую строку build_model (), все будет работать. Что делает clip.model.build_model (), что вызывает проблемы?