Ich möchte die Variablen sehen, die in einem TensorFlow-Prüfpunkt zusammen mit ihren Werten gespeichert werden. Wie finde ich die Variablennamen, die in einem TensorFlow-Prüfpunkt gespeichert sind?
Ich habe tf.train.NewCheckpointReader
verwendet, der hier erklärt wird. Es ist jedoch nicht in der Dokumentation von TensorFlow enthalten. Gibt es einen anderen Weg?
Sie können das Tool inspect_checkpoint.py
verwenden.
Wenn Sie beispielsweise den Prüfpunkt im aktuellen Verzeichnis gespeichert haben, können Sie die Variablen und ihre Werte wie folgt drucken
import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
latest_ckp = tf.train.latest_checkpoint('./')
print_tensors_in_checkpoint_file(latest_ckp, all_tensors=True, tensor_name='')
Verwendungsbeispiel:
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
import os
checkpoint_path = os.path.join(model_dir, "model.ckpt")
# List ALL tensors example output: v0/Adam (DT_FLOAT) [3,3,1,80]
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='')
# List contents of v0 tensor.
# Example output: tensor_name: v0 [[[[ 9.27958265e-02 7.40226209e-02 4.52989563e-02 3.15700471e-02
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0')
# List contents of v1 tensor.
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1')
Update:all_tensors
-Argument wurde zu print_tensors_in_checkpoint_file
seit Tensorflow 0.12.0-rc0 hinzugefügt. Daher müssen Sie ggf. all_tensors=False
oder all_tensors=True
hinzufügen.
Alternative Methode:
from tensorflow.python import pywrap_tensorflow
import os
checkpoint_path = os.path.join(model_dir, "model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor_name: ", key)
print(reader.get_tensor(key)) # Remove this is you want to print only variable names
Ich hoffe es hilft.
Noch ein paar Details.
Wenn Ihr Modell im V2-Format gespeichert wird, wenn sich zum Beispiel folgende Dateien im Verzeichnis /my/dir/
befinden
model-10000.data-00000-of-00001
model-10000.index
model-10000.meta
dann sollte der file_name
-Parameter nur das Präfix sein
print_tensors_in_checkpoint_file(file_name='/my/dir/model_10000', tensor_name='', all_tensors=True)
Siehe https://github.com/tensorflow/tensorflow/issues/7696 für eine Diskussion.
Hinzufügen weiterer Parameterdetails zu print_tensors_in_checkpoint_file
file_name
: keine physische Datei, sondern nur das Präfix der Dateinamen
Wenn kein tensor_name
angegeben ist, werden die Tensornamen und -formen in der Checkpoint-Datei ausgegeben. Wenn tensor_name
angegeben ist, wird der Inhalt des Tensors gedruckt. ( inspect_checkpoint.py )
Wenn all_tensor_names
True
ist, werden alle Tensornamen gedruckt
Wenn all_tensor
"True" ist, werden alle Tensornamen und der entsprechende Inhalt gedruckt.
N.B.all_tensor
und all_tensor_names
überschreiben tensor_name