Как использовать насмешку для сравнения реальных результатов

#python #mocking #pytest #python-unittest #pytest-mock

Вопрос:

У меня есть пример класса, который считывает сохраненную модель тензорного потока и выполняет прогнозы

 class Sample():
   ## all it does is creates a new column with predictions

   def __init__(self, tf_model):
      self.tf_model = tf_model

   def tf_process(self, x):

       ##some other preprocessing
       x["tf_predictions"] = self.tf_model.predict(x)
       return x
   
   def predict(self, x):
       predictions = self.tf_process(x)
       return predictions
 

Код для тестирования без необходимости загрузки модели:

 import unittest
import pandas as pd
from unittest import TestCase, mock
from my_package.sample_model import Sample

class TestSample(unittest.TestCase):

   def test_predict(self):
      with mock.patch("Sample.tf_process") as process:
         process.return_value = pd.DataFrame("hardcoded_value")
         #to check: process.return_value = Output (Sample.predict())
         
 

Цель:

Чтобы сравнить process.return_value с Output of predict method in Sample , но для этого мне все равно нужно загрузить модель, я не понимаю, в чем здесь польза mock , так как мне все равно придется вызывать predict метод для сравнения process.return_value . Любые предложения будут полезны

Комментарии:

1. Ну, это зависит от того, что вы тестируете. Если вы хотите проверить предсказание, вы не можете издеваться над ним, это противоречило бы цели. Если вы хотите протестировать tf_process без тестирования predict , predict вместо этого издевайтесь. С вашим текущим издевательством вы можете проверить только те predict вызовы tf_process с правильным аргументом.

2. @MrBeanBremen Я просто хочу проверить, создана ли колонка tf_predictions ? Любые предложения о том, как это проверить.

3. Как я уже писал — в этом случае вы можете издеваться predict (не зная, к какому классу он принадлежит), чтобы вернуть что-то разумное.

Ответ №1:

Я думаю, что в вашем случае это лучше использовать Mock() . Вы можете создавать действительно хорошие и простые тесты и без patch() этого . Просто подготовьте все необходимые издевательские экземпляры для инициализации.

 from unittest.mock import Mock


class TestSample(TestCase):
    def test_predict(self):
        # let's say predict() will return something... just an example
        tf = Mock(predict=Mock(return_value=(10, 20, 30)))
        df = pd.DataFrame({'test_col': (1, 2, 3)})
        df = Sample(tf).predict(df)
        # check column
        self.assertTrue('tf_predictions' in df.columns)
        # or check records
        self.assertEqual(
            df.to_dict('records'),
            [
                {'test_col': 1, 'tf_predictions': 10},
                {'test_col': 2, 'tf_predictions': 20},
                {'test_col': 3, 'tf_predictions': 30}
            ]
        )
 

Также это действительно помогает, когда вам нужны тесты для сложных сервисов. Просто пример:

 class ClusterService:
    def __init__(self, service_a, service_b, service_c) -> None:
        self._service_a = service_a
        self._service_b = service_b
        self._service_c = service_c
        # service_d, ... etc

    def get_cluster_info(self, name: str):
        self._service_a.send_something_to_somewhere(name)
        data = {
            'name': name,
            'free_resources': self._service_b.get_free_resources(),
            'current_price': self._service_c.get_price(name),
        }

        return ' ,'.join([
            ': '.join(['Cluster name', name]),
            ': '.join(['CPU', str(data['free_resources']['cpu'])]),
            ': '.join(['RAM', str(data['free_resources']['ram'])]),
            ': '.join(['Price', '{} 



.format(round(data['current_price']['usd'], 2))]),
])

class TestClusterService(TestCase):
def test_get_cluster_info(self):
cluster = ClusterService(
service_a=Mock(),
service_b=Mock(get_free_resources=Mock(return_value={'cpu': 100, 'ram': 200})),
service_c=Mock(get_price=Mock(return_value={'usd': 101.4999})),
)

self.assertEqual(
cluster.get_cluster_info('best name'),
'Cluster name: best name ,CPU: 100 ,RAM: 200 ,Price: 101.5

)