Вывод с помощью обслуживания Tensorflow с использованием Java

#tensorflow #tensorflow-serving #grpc-java

#tensorflow #обслуживание tensorflow #grpc-java

Вопрос:

Мы переводим существующий производственный код Java на использование Tensorflow Serving (TFS) для вывода. Мы уже переподготовили наши модели и сохранили их, используя новый формат SavedModel (больше никаких замороженных графиков !!).
Из документации, которую я прочитал, следует, что TFS напрямую не поддерживает Java. Однако он предоставляет интерфейс gRPC, и это действительно обеспечивает интерфейс Java.

Мой вопрос: каковы шаги, связанные с подготовкой Java-приложения к использованию TFS.

[Правка: перенесенные шаги к решению]

Ответ №1:

Потребовалось четыре дня, чтобы собрать все воедино, поскольку документация и примеры по-прежнему ограничены.
Я уверен, что есть лучшие способы сделать это, но это то, что я нашел до сих пор:

  • Я клонировал файлы tensorflow/tensorflow tensorflow/serving и google/protobuf репозитории на github.
  • Я скомпилировал следующие файлы protobuf, используя protoc компилятор protobuf с grpc-java плагином. Я ненавижу тот факт, что нужно скомпилировать так много разрозненных .proto файлов, но я хотел включить минимальный набор, а в различных каталогах так много ненужных .proto файлов, которые были бы втянуты. Вот минимальный набор, который мне нужен для компиляции нашего Java-приложения:
    • serving_repo/tensorflow_serving/apis/*.proto
    • serving_repo/tensorflow_serving/config/model_server_config.proto
    • serving_repo/tensorflow_serving/core/logging.proto
    • serving_repo/tensorflow_serving/core/logging_config.proto
    • serving_repo/tensorflow_serving/util/status.proto
    • serving_repo/tensorflow_serving/sources/storage_path/file_system_storage_path_source.proto
    • serving_repo/tensorflow_serving/config/log_collector_config.proto
    • tensorflow_repo/tensorflow/core/framework/tensor.proto
    • tensorflow_repo/tensorflow/core/framework/tensor_shape.proto
    • tensorflow_repo/tensorflow/core/framework/types.proto
    • tensorflow_repo/tensorflow/core/framework/resource_handle.proto
    • tensorflow_repo/tensorflow/core/example/example.proto
    • tensorflow_repo/tensorflow/core/protobuf/tensorflow_server.proto
    • tensorflow_repo/tensorflow/core/example/feature.proto
    • tensorflow_repo/tensorflow/core/protobuf/named_tensor.proto
    • tensorflow_repo/tensorflow/core/protobuf/config.proto
  • Обратите внимание, что это protoc будет компилироваться даже без grpc-java present, однако большинство критических точек входа будут таинственным образом отсутствовать. Если PredictionServiceGrpc.java отсутствует, то grpc-java не выполняется.
  • Пример командной строки (со вставленными разрывами строк для удобства чтения):
 $ ./protoc -I=/Users/foobar/protobuf_repo/src 
   -I=/Users/foobar/tensorflow_repo    
   -I=/Users/foobar/tfserving_repo   
   -plugin=protoc-gen-grpc-java=/Users/foobar/protoc-gen-grpc-java-1.20.0-osx-x86_64.exe 
   --java_out=src 
   --grpc-java_out=src 
   /Users/foobar/tfserving_repo/tensorflow_serving/apis/*.proto
 
 ManagedChannel mChannel;
PredictionServiceGrpc.PredictionServiceBlockingStub mBlockingstub;
mChannel = ManagedChannelBuilder.forAddress(host,port).usePlaintext().build();
mBlockingstub = PredictionServiceGrpc.newBlockingStub(mChannel);
 
  • Я следил за несколькими документами, чтобы собрать воедино следующие шаги:
    • В документах gRPC обсуждаются заглушки (блокировка и синхронизация)
    • В этой статье рассматривается процесс, но с использованием Python
    • Этот пример кода имел решающее значение для примеров синтаксиса newBuilder.
  • Импорт Maven:
    • io.grpc:grpc-all
    • org.tensorflow:libtensorflow
    • org.tensorflow:proto
    • com.google.protobuf:protobuf-java
  • Вот пример кода:
 // Generate features TensorProto
TensorProto.Builder featuresTensorBuilder = TensorProto.newBuilder();

TensorShapeProto.Dim featuresDim1  = TensorShapeProto.Dim.newBuilder().setSize(1).build();
TensorShapeProto     featuresShape = TensorShapeProto.newBuilder().addDim(featuresDim1).build();
featuresTensorBuilder.setDtype(org.tensorflow.framework.DataType).setTensorShape(featuresShape);
TensorProto featuresTensorProto = featuresTensorBuilder.build();


// Now prepare for the inference request over gRPC to the TF Serving server
com.google.protobuf.Int64Value version = com.google.protobuf.Int64Value.newBuilder().setValue(mGraphVersion).build();

Model.ModelSpec.Builder model = Model.ModelSpec
                                     .newBuilder()
                                     .setName(mGraphName)
                                     .setVersion(version);  // type = Int64Value
Model.ModelSpec     modelSpec = model.build();

Predict.PredictRequest request;
request = Predict.PredictRequest.newBuilder()
                                .setModelSpec(modelSpec)
                                .putInputs("image", featuresTensorProto)
                                .build();

Predict.PredictResponse response;

try {
    response = mBlockingstub.predict(request);
    // Refer to https://github.com/thammegowda/tensorflow-grpc-java/blob/master/src/main/java/edu/usc/irds/tensorflow/grpc/TensorflowObjectRecogniser.java

    java.util.Map<java.lang.String, org.tensorflow.framework.TensorProto> outputs = response.getOutputsOrDefault();
    for (java.util.Map.Entry<java.lang.String, org.tensorflow.framework.TensorProto> entry : outputs.entrySet()) {
        System.out.println("Response with the key: "   entry.getKey()   ", value: "   entry.getValue());
    }
} catch (StatusRuntimeException e) {
    logger.log(Level.WARNING, "RPC failed: {0}", e.getStatus());
    success = false;
}