web-dev-qa-db-de.com

Wie finde ich die Variablennamen und -werte, die in einem Prüfpunkt gespeichert sind?

Ich möchte die Variablen sehen, die in einem TensorFlow-Prüfpunkt zusammen mit ihren Werten gespeichert werden. Wie finde ich die Variablennamen, die in einem TensorFlow-Prüfpunkt gespeichert sind?

Ich habe tf.train.NewCheckpointReader verwendet, der hier erklärt wird. Es ist jedoch nicht in der Dokumentation von TensorFlow enthalten. Gibt es einen anderen Weg?

23
Tavakoli

Sie können das Tool inspect_checkpoint.py verwenden.

Wenn Sie beispielsweise den Prüfpunkt im aktuellen Verzeichnis gespeichert haben, können Sie die Variablen und ihre Werte wie folgt drucken

import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file


latest_ckp = tf.train.latest_checkpoint('./')
print_tensors_in_checkpoint_file(latest_ckp, all_tensors=True, tensor_name='')
11
keveman

Verwendungsbeispiel:

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
import os
checkpoint_path = os.path.join(model_dir, "model.ckpt")

# List ALL tensors example output: v0/Adam (DT_FLOAT) [3,3,1,80]
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='')

# List contents of v0 tensor.
# Example output: tensor_name:  v0 [[[[  9.27958265e-02   7.40226209e-02   4.52989563e-02   3.15700471e-02
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0')

# List contents of v1 tensor.
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1')

Update:all_tensors-Argument wurde zu print_tensors_in_checkpoint_file seit Tensorflow 0.12.0-rc0 hinzugefügt. Daher müssen Sie ggf. all_tensors=False oder all_tensors=True hinzufügen.

Alternative Methode:

from tensorflow.python import pywrap_tensorflow
import os

checkpoint_path = os.path.join(model_dir, "model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()

for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key)) # Remove this is you want to print only variable names

Ich hoffe es hilft.

36
sagunms

Noch ein paar Details.

Wenn Ihr Modell im V2-Format gespeichert wird, wenn sich zum Beispiel folgende Dateien im Verzeichnis /my/dir/ befinden

model-10000.data-00000-of-00001
model-10000.index
model-10000.meta

dann sollte der file_name-Parameter nur das Präfix sein

print_tensors_in_checkpoint_file(file_name='/my/dir/model_10000', tensor_name='', all_tensors=True)

Siehe https://github.com/tensorflow/tensorflow/issues/7696 für eine Diskussion.

9
deeplearning

Hinzufügen weiterer Parameterdetails zu print_tensors_in_checkpoint_file

file_name: keine physische Datei, sondern nur das Präfix der Dateinamen

Wenn kein tensor_name angegeben ist, werden die Tensornamen und -formen in der Checkpoint-Datei ausgegeben. Wenn tensor_name angegeben ist, wird der Inhalt des Tensors gedruckt. ( inspect_checkpoint.py )

Wenn all_tensor_namesTrue ist, werden alle Tensornamen gedruckt

Wenn all_tensor "True" ist, werden alle Tensornamen und der entsprechende Inhalt gedruckt.

N.B.all_tensor und all_tensor_names überschreiben tensor_name

0
BugKiller