Как вернуть определенные узлы с одинаковым значением степени в графике в PyTorch?

#python #python-3.x #pytorch

Вопрос:

У меня есть функция, которая строит график и находит степень каждого отдельного узла, а затем вычисляет количество узлов, имеющих одинаковое значение степени:

 import torch
from torch_geometric.utils import degree

def fun(self, graph):
        
        n = graph.num_nodes
        d = degree(graph.edge_index[1], n, dtype=torch.long)
        counts = torch.bincount(d)

        return counts
 

Вышеуказанная функция работает нормально. Но я хочу, чтобы он находил только узлы, у которых степень меньше 50, а затем возвращал количество узлов (степени:

 def fun(self, graph):
        
        n = graph.num_nodes
        deg = degree(graph.edge_index[1], n, dtype=torch.long)
        less_than_fifty = [i if i < 50 else 0 for i in deg]
        counts = torch.bincount(less_than_fifty)

        return counts

 

После его запуска возникла следующая ошибка:

 TypeError: bincount(): argument 'input' (position 1) must be Tensor, not list
 

Поэтому я использовал тензор вместо списка следующим образом:

 def fun(self, graph):
        
        n = graph.num_nodes
        d = degree(graph.edge_index[1], n, dtype=torch.long)
        less_than_fifty = torch.tensor([i if i < 50 else 0 for i in deg])
        counts = torch.bincount(less_than_fifty)

        return counts
 

Но на этот раз возникла другая проблема. Я запустил код в Google Colab, из-за последней модификации (преобразования списка в тензор) Colab продолжал сбоить. Я использовал графический процессор на Colab. Я уверен, что причиной сбоя была строка less_than_fifty = torch.tensor([i if i < 50 else 0 for i in deg]) , так как всякий раз, когда я ее удалял, сбоя больше не было. Мой вопрос в том, как устранить проблемы?