Какая точность используется xarray weighted().mean() ?

#python #precision #python-xarray #weighted-average

#python #точность #python-xarray #средневзвешенное значение

Вопрос:

Я хотел бы понять, когда weighted().mean() использует простую или двойную точность.
Вот краткий пример проблемы, с которой я столкнулся:

 import numpy as np
import xarray as xr
data = xr.tutorial.open_dataset("air_temperature")
weights = np.cos( np.deg2rad(data.lat) )   # weights with the cosinus of the latitude

# data are 32 bytes
data.air.dtype   # --> dtype('float32')
weights.dtype    # --> dtype('float32')

# mean() uses 32b
data.air.isel(time=1).mean().dtype   # --> dtype('float32')

# weighted().mean() uses 64b in some cases...
data.air.isel(time=1).weighted(weights).mean().dtype            # --> dtype('float64')
data.air.isel(time=slice(1,2)).weighted(weights).mean().dtype   # --> dtype('float64')

# ... but weighted().mean() keep using 32b if mean is not applied to all dimensions, why?
data.air.isel(time=slice(1,2)).weighted(weights).mean(("lon", "lat")).dtype   # --> dtype('float32')
  

Почему weighted().mean() использует 32 байта вместо 64, если mean() применяется не ко всем измерениям?

Ответ №1:

редактировать: я только что обновил свои среды, и этого больше не произошло. Я не уверен, какая библиотека отвечает (кажется, не numpy)


Это связано с несколько непоследовательным поведением where in xarray. Сравнить

 dta = data.air.isel(time=slice(1,2))

# only one element
dta.where([True]).dtype # -> float64
# many elements
dta.where(dta.lat > 0) # -> float32
  

Я не уверен, почему where поведение отличается в двух ситуациях. Итак, я думаю, вы получаете float64 , когда уменьшаете все измерения.

Это происходит внутри _sum_of_weights .

 sum_of_weights = xr.dot(dta.notnull(), weights)
valid_weights = sum_of_weights != 0.0
print(sum_of_weights.where(valid_weights).dtype) # -> float32


sum_of_weights = xr.dot(dta.notnull(), weights, dims=...)
valid_weights = sum_of_weights != 0.0
print(sum_of_weights.where(valid_weights).dtype) # -> float64