Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

python - TensorFlow: delete nodes from a graph

I'm trying to delete some nodes from a graph and save the result as a .pb file.

I can add only the needed nodes to a new graph, mod_graph_def, but the remaining nodes still list the deleted nodes among their inputs, and I can't modify a node's inputs:

def delete_ops_from_graph():
    with open(input_model_filepath, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    nodes = []
    for node in graph_def.node:
        if 'Neg' in node.name:
            print('Drop', node.name)
        else:
            nodes.append(node)

    mod_graph_def = tf.GraphDef()
    mod_graph_def.node.extend(nodes)

    # Problem: the inputs of the remaining nodes still reference the deleted nodes
    for node in mod_graph_def.node:
        inp_names = []
        for inp in node.input:
            if 'Neg' in inp:
                pass
            else:
                inp_names.append(inp)

        node.input = inp_names # TypeError: Can't set composite field

    with open(output_model_filepath, 'wb') as f:
        f.write(mod_graph_def.SerializeToString())

1 Answer

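import tensorflow as tf  # TF 1.x API; in TF 2.x, GraphDef is available as tf.compat.v1.GraphDef

def delete_ops_from_graph():
    # input_model_filepath and output_model_filepath are expected to be defined elsewhere
    with open(input_model_filepath, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Delete nodes: keep only the nodes whose name does not contain 'Neg'
    nodes = []
    for node in graph_def.node:
        if 'Neg' in node.name:
            print('Drop', node.name)
        else:
            nodes.append(node)

    mod_graph_def = tf.GraphDef()
    mod_graph_def.node.extend(nodes)

    # Delete references to deleted nodes
    for node in mod_graph_def.node:
        inp_names = [inp for inp in node.input if 'Neg' not in inp]

        # node.input is a protobuf repeated field and cannot be assigned directly,
        # so clear it in place and then refill it
        del node.input[:]
        node.input.extend(inp_names)

    with open(output_model_filepath, 'wb') as f:
        f.write(mod_graph_def.SerializeToString())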
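
Note that the code above simply drops every input that mentions a deleted node, so a consumer can end up with fewer inputs than its op expects. A common alternative is to rewire each consumer to whatever fed the deleted node. The following is only a rough sketch of that idea, using plain dicts as stand-ins for NodeDef messages so it runs without TensorFlow. The helper names `base_name` and `drop_and_rewire` are hypothetical, and bypassing a dropped node through its first input is an assumption that suits single-input ops such as Neg. On a real GraphDef the resulting input lists would still be written back with `del node.input[:]` followed by `node.input.extend(...)`.

```python
def base_name(inp):
    """Strip the control-dependency marker (^) and the output index (:N) from an input string."""
    name = inp[1:] if inp.startswith('^') else inp
    return name.split(':')[0]


def drop_and_rewire(nodes, should_drop):
    """Return the kept nodes, with inputs rewired past the dropped ones.

    nodes: list of dicts {'name': str, 'input': [str, ...]} (stand-ins for NodeDef).
    should_drop: predicate that takes a node dict and returns True to delete it.
    """
    by_name = {n['name']: n for n in nodes}
    dropped = {n['name'] for n in nodes if should_drop(n)}

    def resolve(inp):
        # Follow chains of dropped nodes until a kept producer is found.
        is_control = inp.startswith('^')
        seen = set()
        while base_name(inp) in dropped:
            name = base_name(inp)
            if name in seen:  # guard against a cycle made only of dropped nodes
                return None
            seen.add(name)
            upstream = by_name[name]['input']
            if not upstream:  # dropped node had no inputs: nothing to rewire to
                return None
            inp = upstream[0]  # assumption: bypass through the first input
        return '^' + base_name(inp) if is_control else inp

    kept = []
    for n in nodes:
        if n['name'] in dropped:
            continue
        resolved = (resolve(i) for i in n['input'])
        kept.append({'name': n['name'], 'input': [r for r in resolved if r is not None]})
    return kept


# Toy graph: x -> Neg -> out, where out also reads x directly
nodes = [
    {'name': 'x', 'input': []},
    {'name': 'Neg', 'input': ['x']},
    {'name': 'out', 'input': ['Neg', 'x:0', '^Neg']},
]
result = drop_and_rewire(nodes, lambda n: 'Neg' in n['name'])
for n in result:
    print(n['name'], n['input'])
```

Comparing against the node name with the `^` prefix and `:N` suffix stripped is also somewhat stricter than the substring test `'Neg' in inp` applied to the raw input string, although the `should_drop` predicate in the toy example still matches names by substring.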
