#python
#python
Вопрос:
Я использую код, который нашел на Github, и пытаюсь его запустить. Однако есть ошибка, которая гласит:
File "/home/anatole2/WVD/SymJAX/symjax/tensor/base.py", line 253, in jax_wrap
symjax._fn_to_op[func] = op
AttributeError: module 'symjax' has no attribute '_fn_to_op'
Я пытался найти _fn_to_op
в каждом файле Python и смог найти только это в __init__.py
.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# flake8: noqa
_fn_to_op = {}
from . import probabilities
from . import nn
from . import tensor
from . import data
from . import rl
from .base import *
from ._version import get_versions
__version__ = get_versions()["version"]
del get_versions
_graphs = [Graph("default")]
__all__ = ["data", "tensor", "nn", "probabilities", "rl"]
Когда я изменил путь на symjax.__init__._fn_to_op
, я получаю ту же ошибку. Я знаю, _fn_to_op
что в этом файле нет функции / атрибута, но это единственное, что я нашел. Что делает функция и как я могу исправить ошибку?
Вот код, который генерирует ошибку:
def jax_wrap(func, doc_func=None):
if doc_func is None:
doc_func = func
@wraps(doc_func)
def op(*args, seed=None, **kwargs):
# if there is a name we remove it for now to use the jax tracer
op_name = kwargs.pop("name", None)
# first we check if we are in a random function to be careful
# with the key. this is just to get shape and dtype so we do not bother
# to use the correct seed yet
is_random = func in symjax.tensor.random._RANDOM_FUNCTIONS
temp_args = ((jax.random.PRNGKey(0),) if is_random else ()) args
tree = get_output_tree(func, *temp_args, **kwargs)
# now we determine what type of Tensor subclass it will produce
feed = {"_jax_function": func, "name": op_name}
if type(tree) == list or type(tree) == tuple:
feed.update(
{
"_shapes": [t.shape for t in tree],
"_dtypes": [t.dtype for t in tree],
}
)
if func == jax.numpy.shape:
return Shape(
*args,
**feed,
**kwargs,
)
else:
return MultiOutputOp(
*args,
**feed,
**kwargs,
)
else:
feed.update({"_shape": tree.shape, "_dtype": tree.dtype})
if is_random:
return RandomOp(*args, _seed=seed, **feed, **kwargs)
else:
return Op(*args, **feed, **kwargs)
symjax._fn_to_op[func] = op
if not hasattr(func, "__doc__") or func.__doc__ is None:
return op
if doc_func is not None:
op.__name__ = func.__name__
return op
Комментарии:
1. Вы пытались связаться с автором кода?
2. Можете ли вы добавить код, который генерирует ошибку?