Обучение формирователя времени для классификации видео

#transformer #custom-training

Вопрос:

Мои входные данные-это карты объектов, а не необработанные изображения. и иметь форму : (4,50,1,1,256)
mini_batch=4 / frames=50 / channels=1 / H=1 / W= 256 Параметрами формирователя времени являются :

 dim = 128,
image_size = 256,
patch_size = 16,
num_frames = 50,
num_classes = 2,
depth = 12,
heads = 8,
dim_head = 32,
attn_dropout = 0.,
ff_dropout = 0.
)
 

Чтобы проверить, работает ли моя сеть, я попытался сделать ее более подходящей, используя только 6 обучающих данных и 2 проверочных данных той же формы, что и раньше (4,50,1,1,256) .
Но точность обучения, которую я получаю, колеблется и никогда не достигает значения >80%, и мои потери в обучении не уменьшаются, они всегда рядом 0.6900 - 06950

Моя функция и параметры обучения таковы:

 
    epochs = 300
    lr = 1e-3
    device = "cuda" if torch.cuda.is_available() else "cpu" 
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    def accuracy(y_pred, y_test):
       y_pred_softmax = torch.log_softmax(y_pred, dim = 1)
       _, y_pred_tags = torch.max(y_pred_softmax, dim = 1)    
       correct_pred = (y_pred_tags == y_test).float()
       acc = correct_pred.sum() / len(correct_pred)
       acc = torch.round(acc * 100)
       return acc
    history = defaultdict(list)
    for epoch in range(epochs):
        epoch_loss = 0
        epoch_accuracy = 0
        model=model.train()
        for data, label in tqdm(train_loader):
            data = data
            label = label
            data=data.reshape(4,50,1,1,256)
            output = model(data)
            label=label.reshape(4,).to(torch.long)
            output = output / output.sum(0).expand_as(output)
            loss = criterion(output,label)
            acc=accuracy(output,label)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()        
            epoch_accuracy  = acc / len(train_loader)
            epoch_loss  = loss / len(train_loader)
        with torch.no_grad():
            epoch_val_accuracy = 0
            epoch_val_loss = 0
            model=model.eval()
            for data, label in val_loader:
                data = data
                label=label.reshape(4,).to(torch.long)
                data=data.reshape(4,50,1,1,256)
                val_output = model(data)
                val_output = val_output / val_output.sum(0).expand_as(val_output)
                val_loss = criterion(val_output, label)
                val_acc=accuracy(val_output,label)
                optimizer.zero_grad()            
                epoch_val_accuracy  = acc / len(val_loader)
                epoch_val_loss  = val_loss / len(val_loader)
 

Я был бы признателен за любое предложение.
Спасибо