The Session graph is empty解决(TensorFlow2.0)

  • 时间:
  • 浏览:
  • 来源:互联网

最近看了一个关于关于四种花分类的tensorflow代码,第一部分是制作TFRecords数据。但对于tensorflow2.0的版本来看,会出现很多bug。例如关于TFRecord数据保存问题,各种函数的调用问题,最绝的是关于The Session graph is empty的错误。因此对代码做出了很多调整,该代码是面对Tensorflow2.0版本,对于其中问题,说一下自己的看法。

首先说一下最近学的TFRecord,这部分是最近自己正在学的, 主要的是掌握固定的数据代码格式,这里可以看这篇文章TFRecord代码实例

关于源代码的 Session部分:

init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:  # 开始一个会话
        sess.run(init_op)

即使在改写成tf.compat.v1.Session,还是会报错,显示The Session graph is empty
关于书中对Tensorflow2.0版本的讲解是:

①在TensorFlow1.x中,最常规的是使用"session.run()"方法执行计算图,“session.run()"方法的调用类似于函数调用,指定输入数据和调用的方法,最后返回结果。
②TensorFlow2.0将Eager Execcution(动态图机制)作为默认模式。在该模式下用户能够轻松地编写和调试代码,可以使用原生的Python控制语句,大大降低了学习和使用TensorFlow的门槛
③在TensorFlow2.0中,图(graph)和会话(Session)都会变成底层实现,而不需要用户关心。
④TensorFlow2.0推出了一个新的运行理念,即AutoGraph。当运行代码时,TensorFlow2.0自动调用Eager模式执行函数,这是内部所完成的。

下面是解决The Session graph is empty.TensorFlow2.0中对Tensorflow1.X的Seesion理解

1.在一个graph上添加节点和连线,也就是构建我们要用的模型。我们将graph理解为一张空白的设计图,而我们在这张纸上画圈画线。
2.在tf.Session()中运行我们的在某一graph上运行的model。可以理解为运用我们画好的设计图纸,创建出一个机器,实现我们需求的功能。因为只有图纸是做不了事情的。
————————————————
版权声明:本文为CSDN博主「qq_35630121」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_35630121/article/details/103349528

以上几点其实可以帮助我们解决问题。我想的是如果是动态图机制,而我的操作又没有写到图中,那我就自己建一张graph不就行了,因此写了以下代码:

g = tf.Graph()
with g.as_default():
    create_record()
    batch = read_and_decode('flower_train.tfrecords')
    print("*********************")
    print(batch)
    print("***********************")
    init_op1 = tf.compat.v1.local_variables_initializer()
    init_op2=tf.compat.v1.global_variables_initializer()
    with tf.compat.v1.Session() as sess:# 开始一个会话
        sess.run(init_op1)
        sess.run(init_op2)
        for i in range(num_samples):
            dataset= sess.run(batch)
            for seq ,lable in dataset:
                print(seq.numpy())
                print(lable.numpy())
                img = Image.fromarray(seq.numpy(), 'RGB')  # 这里Image是之前提到的
                img.save(gen_picture + '/' + str(i) + 'samples' + str(lable.numpy()) + '.jpg')  # 存下图片;注意cwd后边加上‘/’
        sess.close()

运行结果如下。可以看出batch = read_and_decode('flower_train.tfrecords')
是直接运行的,但这一步不是只是写在图纸上,到session.run()时调用函数才会运行吗?

*********************
<MapDataset shapes: ((64, 64, 3), ()), types: (tf.uint8, tf.int32)>
***********************

对此我们需要知道,TensorFlow1.x版本之所以要分为两步进行,是因为它无法使用Python支持的常用代码,所以需要将Python代码进行编译为TensorFlow所内置的API和函数才能执行。而2.0版本可以使用"tf.function"来修饰Python函数,以将其标记为即时编译,从而TensorFlow可以将其作为单个图来执行。
所以是不是只用正常的Python代码的格式运行TensorFlow就可以了?!(自己的一点点理解,很可能是错的。。先说对不起==)不过确实是可以运行的。

数据链接:https://pan.baidu.com/s/1hPs0XeFbZ9gjL8Z6Rb7SLA
提取码:60c9

修改后代码如下:

# 将原始图片转换成需要的大小,并将其保存
# ========================================================================================
import os
import tensorflow as tf
from PIL import Image
import numpy as np
# 原始图片的存储位置
orig_picture = 'input_data'

# 生成图片的存储位置
gen_picture = 'newinput_data'

# 需要的识别类型
classes = {'dandelion', 'roses', 'sunflowers','tulips'}

# 样本总数
num_samples = 4000


# 制作TFRecords数据
def create_record():
    writer = tf.io.TFRecordWriter("flower_train.tfrecords")
    for index, name in enumerate(classes):
        class_path = orig_picture + "/" + name + "/"
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name
            img = Image.open(img_path)
            img = img.resize((64, 64))  # 设置需要转换的图片大小
            img_raw = img.tobytes()  # 将图片转化为原生bytes
            # print(index, img_raw)
            example = tf.train.Example(
                features=tf.train.Features(feature={
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                    'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
                }))
            writer.write(example.SerializeToString())
    writer.close()

def single_example_parser(x):
    features = tf.io.parse_single_example(
        x, features={
            'label': tf.io.FixedLenFeature([], tf.int64),
            'img_raw': tf.io.FixedLenFeature([], tf.string)
        }
    )
    label = features['label']
    img = features['img_raw']
    img = tf.io.decode_raw(img, tf.uint8)
    img = tf.reshape(img, [64, 64, 3])
    print(img)
    # img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(label, tf.int32)
    print(label)
    return img, label


# =======================================================================================
def read_and_decode(filename):
    # 创建文件队列,不限读取的数量
    filename_queue = tf.data.TFRecordDataset([filename])
    # # create a reader from file queue
    # reader = tf.compat.v1.TFRecordReader()
    # # reader从文件队列中读入一个序列化的样本
    # _, serialized_example = reader.read(filename_queue)
    # # get feature from serialized example
    #
    # print(type(filename_queue))
    filename_queue=filename_queue.map(lambda x:single_example_parser(x))
    # print(type(filename_queue))
    return filename_queue



# =======================================================================================
tf.compat.v1.local_variables_initializer()
tf.compat.v1.global_variables_initializer()
create_record()
for i in range(num_samples):
    dataset = read_and_decode('flower_train.tfrecords')   # 在会话中取出image和label
    # print(dataset)
    # li=list(dataset.as_numpy_iterator())
    # print(np.asarray(li).shape)
    for seq ,lable in dataset:
        print(seq.numpy())
        print(lable.numpy())
        img = Image.fromarray(seq.numpy(), 'RGB')  # 这里Image是之前提到的
        img.save(gen_picture + '/' + str(i) + 'samples' + str(lable.numpy()) + '.jpg')  # 存下图片;注意cwd后边加上‘/’


原始代码如下,链接https://github.com/waitingfordark/four_flower/blob/master/create%20record.py

# 将原始图片转换成需要的大小,并将其保存
# ========================================================================================
import os
import tensorflow as tf
from PIL import Image

# 原始图片的存储位置
orig_picture = 'D:/ML/flower/flower_photos/'

# 生成图片的存储位置
gen_picture = 'D:/ML/flower/input_data/'

# 需要的识别类型
classes = {'dandelion', 'roses', 'sunflowers','tulips'}

# 样本总数
num_samples = 4000


# 制作TFRecords数据
def create_record():
    writer = tf.python_io.TFRecordWriter("flower_train.tfrecords")
    for index, name in enumerate(classes):
        class_path = orig_picture + "/" + name + "/"
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name
            img = Image.open(img_path)
            img = img.resize((64, 64))  # 设置需要转换的图片大小
            img_raw = img.tobytes()  # 将图片转化为原生bytes
            print(index, img_raw)
            example = tf.train.Example(
                features=tf.train.Features(feature={
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                    'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
                }))
            writer.write(example.SerializeToString())
    writer.close()


# =======================================================================================
def read_and_decode(filename):
    # 创建文件队列,不限读取的数量
    filename_queue = tf.train.string_input_producer([filename])
    # create a reader from file queue
    reader = tf.TFRecordReader()
    # reader从文件队列中读入一个序列化的样本
    _, serialized_example = reader.read(filename_queue)
    # get feature from serialized example
    # 解析符号化的样本
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string)
        })
    label = features['label']
    img = features['img_raw']
    img = tf.decode_raw(img, tf.uint8)
    img = tf.reshape(img, [64, 64, 3])
    # img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(label, tf.int32)
    return img, label


# =======================================================================================
if __name__ == '__main__':
    create_record()
    batch = read_and_decode('flower_train.tfrecords')
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    with tf.Session() as sess:  # 开始一个会话
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        for i in range(num_samples):
            example, lab = sess.run(batch)  # 在会话中取出image和label
            img = Image.fromarray(example, 'RGB')  # 这里Image是之前提到的
            img.save(gen_picture + '/' + str(i) + 'samples' + str(lab) + '.jpg')  # 存下图片;注意cwd后边加上‘/’
            print(example, lab)
        coord.request_stop()
        coord.join(threads)
        sess.close()

本文链接http://www.dzjqx.cn/news/show-617218.html