#python #pytorch
#питон #пыторч
Вопрос:
для добавления условий в VAE я использовал этот набор данных и метки, но в части кода я получил ошибку: вот мой генератор данных:
def gen_batch(BATCH_SIZE): labels = torch.randint(0, 8, (BATCH_SIZE,)).long().to(device) theta = (np.pi/4) * labels.float().to(device) centers = torch.stack((torch.cos(theta), torch.sin(theta)), dim = -1) noise = torch.randn_like(centers) * 0.1 return centers noise, labels def data_gen(BATCH_SIZE): #8 gaussians while 1: yield gen_batch(BATCH_SIZE) train_loader,train_labels = next(data_gen(args.batch_size))
и я получил ошибку в этой части:
def one_hot(labels, class_size): targets = torch.zeros(labels.size(0), class_size) for i, label in enumerate(labels): targets[i, label] = 1 return Variable(targets)
IndexError Traceback (most recent call last) lt;ipython-input-74-bbc3d925d933gt; in lt;modulegt; 7 for epoch in tqdm(range(1, args.epochs 1)): 8 ----gt; 9 a,b= train(epoch) 10 # with open("output.txt","a") as f: 11 # print("epoch:",epoch,",","re:",2 a,",","li:",b,",","los", file=f) lt;ipython-input-73-0f6a4be71f5cgt; in train(epoch) 7 break #100 batches per epoch 8 # data = data.to(device) ----gt; 9 data, cond = data.to(device), one_hot(cond, cond_dim).to(device) 10 optimizer.zero_grad() 11 recon_batch, mu, logvar = cvae(data, cond) lt;ipython-input-72-1088648ceae5gt; in one_hot(labels, class_size) 1 def one_hot(labels, class_size): ----gt; 2 targets = torch.zeros(labels.size(0), class_size) 3 for i, label in enumerate(labels): 4 targets[i, label] = 1 5 return Variable(targets) IndexError: dimension specified as 0 but tensor has no dimensions
что я должен сделать, чтобы это исправить?