#python #gradient #torch #loss-function #backpropagation
Question:
import torch

def svm_loss_naive(W, X, y, reg):
"""
Structured SVM loss function, naive implementation (with loops).
Inputs have dimension D, there are C classes, and we operate on minibatches
of N examples. When you implment the regularization over W, please DO NOT
multiply the regularization term by 1/2 (no coefficient).
Inputs:
- W: A PyTorch tensor of shape (D, C) containing weights.
- X: A PyTorch tensor of shape (N, D) containing a minibatch of data.
- y: A PyTorch tensor of shape (N,) containing training labels; y[i] = c means
that X[i] has label c, where 0 <= c < C.
- reg: (float) regularization strength
Returns a tuple of:
- loss as torch scalar
- gradient of loss with respect to weights W; a tensor of same shape as W
"""
    dW = torch.zeros_like(W)  # initialize the gradient as zero

    # compute the loss and the gradient
    num_classes = W.shape[1]
    num_train = X.shape[0]
    loss = 0.0
    for i in range(num_train):
        scores = W.t().mv(X[i])
        correct_class_score = scores[y[i]]
        for j in range(num_classes):
            if j == y[i]:
                continue
            margin = scores[j] - correct_class_score + 1  # note delta = 1
            if margin > 0:
                loss += margin
                #######################################################
                # TODO: Compute the gradient of the loss function and
                # store it in dW. (part 1) Rather than first computing
                # the loss and then computing the derivative, it is
                # simpler to compute the derivative at the same time
                # that the loss is being computed.
                #######################################################
                # Replace "pass" statement with your code
                # A positive margin means column j scored a wrong class
                # too high: it contributes +X[i] to that column's
                # gradient and -X[i] to the correct class's column.
                dW[:, j] += X[i]
                dW[:, y[i]] -= X[i]
                #######################################################
                # END OF YOUR CODE
                #######################################################

    # Average over the batch, then add regularization (no 1/2 coefficient,
    # per the docstring), so the gradient of the reg term is 2 * reg * W.
    loss /= num_train
    loss += reg * torch.sum(W * W)
    dW /= num_train
    dW += 2 * reg * W

    return loss, dW
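For context, here is the math the double loop implements. The multiclass SVM loss for one example $(x_i, y_i)$ is

$$
L_i = \sum_{j \neq y_i} \max\bigl(0,\; w_j^\top x_i - w_{y_i}^\top x_i + 1\bigr),
$$

where $w_j$ is the $j$-th column of $W$. Differentiating each hinge term with respect to the columns of $W$ gives

$$
\frac{\partial L_i}{\partial w_j} = \mathbb{1}\bigl[\,w_j^\top x_i - w_{y_i}^\top x_i + 1 > 0\,\bigr]\, x_i,
\qquad
\frac{\partial L_i}{\partial w_{y_i}} = -\sum_{j \neq y_i} \mathbb{1}\bigl[\,w_j^\top x_i - w_{y_i}^\top x_i + 1 > 0\,\bigr]\, x_i.
$$

So every time `margin > 0` fires, the wrong class's column gets `+X[i]` and the correct class's column gets `-X[i]`, which is exactly what the two `dW` lines do.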
The task is to compute the gradient of the loss function. I couldn't solve it on my own, so I found a solution on GitHub; many of the approaches were exactly the same as the one above. Even now I still don't understand it. Any help would be much appreciated.
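To check whether the result is at least correct, here is a minimal gradient check I put together (my own sketch, not part of the assignment; it assumes svm_loss_naive is completed as above, and the shapes and reg value are arbitrary). It recomputes the same loss with tensor ops and lets torch.autograd produce the gradient for comparison:

import torch

torch.manual_seed(0)
N, D, C = 5, 4, 3                     # arbitrary small sizes
X = torch.randn(N, D, dtype=torch.float64)
y = torch.randint(C, (N,))
W = torch.randn(D, C, dtype=torch.float64, requires_grad=True)
reg = 0.1

# Analytic gradient from the looped implementation above.
loss, dW = svm_loss_naive(W.detach(), X, y, reg)

# Same loss, written with tensor ops so autograd can differentiate it.
scores = X.mm(W)                                   # (N, C)
correct = scores[torch.arange(N), y].unsqueeze(1)  # (N, 1)
margins = (scores - correct + 1.0).clamp(min=0)    # delta = 1
mask = torch.ones_like(margins)
mask[torch.arange(N), y] = 0.0                     # skip the j == y[i] terms
auto_loss = (margins * mask).sum() / N + reg * (W * W).sum()
auto_loss.backward()

print(torch.allclose(dW, W.grad))  # expect True if the loops are right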