Source code for langml.model

# -*- coding: utf-8 -*-

import os
import random
import string
from typing import Any

import tensorflow as tf
from langml import TF_KERAS
if TF_KERAS:
    import tensorflow.keras.backend as K
else:
    import keras.backend as K

from langml.tensor_typing import Models
from langml.log import info, warn
from langml import TF_VERSION


# MetaGraph tag written by export_model_v1 and passed to the TF 1.x loader
# in load_frozen, so the two must agree.
SAVED_MODEL_TAG = "serve"
def get_random_string(length):
    """Return ``length`` randomly chosen lowercase ASCII letters.

    Draws from the (non-cryptographic) :mod:`random` module, one
    ``random.choice`` call per character. A non-positive ``length``
    yields an empty string.
    """
    letters = []
    for _ in range(length):
        letters.append(random.choice(string.ascii_lowercase))
    return ''.join(letters)
def _tensor_info_map(tensors):
    """Map each tensor's graph name (``:<index>`` suffix stripped) to its TensorInfo.

    :param tensors: a single tensor, or a tuple/list of tensors.
    :return: dict usable as ``inputs``/``outputs`` of ``build_signature_def``.
    """
    if not isinstance(tensors, (tuple, list)):
        tensors = [tensors]
    return {
        tensor.name.split(':')[0]: tf.saved_model.build_tensor_info(tensor)
        for tensor in tensors
    }


def export_model_v1(model, export_model_dir):
    """Export a Keras model as a TensorFlow 1.x SavedModel.

    The model's inputs/outputs become a signature registered under both the
    ``'predict'`` and ``'serving_default'`` keys, tagged with
    ``SAVED_MODEL_TAG``. Variables are taken from ``K.get_session()``.

    :param model: Keras model exposing ``input`` and ``output`` (each a tensor
        or a tuple/list of tensors).
    :param export_model_dir: target directory. If it already exists, a random
        ``.<6 lowercase letters>`` suffix is appended (retried until the path
        is free) and a warning is logged.
    :return: None
    """
    if os.path.exists(export_model_dir):
        warn(f'path `{export_model_dir}` exists!')
        relocated = f"{export_model_dir}.{get_random_string(6)}"
        # A single random suffix can still collide; keep drawing until free.
        while os.path.exists(relocated):
            relocated = f"{export_model_dir}.{get_random_string(6)}"
        export_model_dir = relocated
        warn(f'auto relocation to `{export_model_dir}`')
    # The directory is intentionally NOT pre-created: SavedModelBuilder creates
    # it itself, and (NOTE(review): per TF 1.x builder_impl — confirm against
    # the minimum supported TF) older 1.x builders raise if it already exists.

    with tf.get_default_graph().as_default():
        info(f"input: {model.input}")
        info(f"output: {model.output}")

        input_map = _tensor_info_map(model.input)
        info(f'input map: {input_map}')
        output_map = _tensor_info_map(model.output)
        info(f'output map: {output_map}')

        prediction_signature = tf.saved_model.build_signature_def(
            inputs=input_map, outputs=output_map)
        info('step1 => prediction_signature created successfully')

        builder = tf.saved_model.builder.SavedModelBuilder(export_model_dir)
        builder.add_meta_graph_and_variables(
            sess=K.get_session(),
            tags=[SAVED_MODEL_TAG],
            signature_def_map={
                'predict': prediction_signature,
                'serving_default': prediction_signature,
            },
        )
        info(f'step2 => Export path({export_model_dir}) ready to export trained model')
        builder.save()
        info(f'done! model has saved to {export_model_dir}.')
def save_frozen(model: Models, fpath: str) -> None:
    """Save ``model`` to ``fpath`` in SavedModel format.

    On TensorFlow 2+ this delegates to ``tf.saved_model.save``. On 1.x it
    falls back to ``export_model_v1``, which relocates ``fpath`` with a random
    suffix if the path already exists.

    :param model: the Keras model to save.
    :param fpath: destination directory.
    :return: None
    """
    # Reuse the package-level TF_VERSION (the same check load_frozen uses)
    # rather than re-parsing tf.__version__ here. Assumes TF_VERSION is the
    # integer major version, as load_frozen's `> 1` comparison implies.
    if TF_VERSION > 1:
        tf.saved_model.save(model, fpath)
        return
    info('apply tensorflow 1.x frozen')
    export_model_v1(model, fpath)
def load_frozen(model_dir: str, session: Any = None) -> Any:
    """Load a SavedModel from ``model_dir``.

    :param model_dir: directory holding the SavedModel.
    :param session: TF 1.x session to restore into. Ignored on TF 2+.
    :return: on TF 2+, the object returned by ``tf.saved_model.load``; on
        TF 1.x, ``session`` itself after the graph tagged
        ``SAVED_MODEL_TAG`` has been loaded into it.
    :raises ValueError: on TF 1.x when no ``session`` is supplied.
    """
    is_tf2_or_later = TF_VERSION > 1
    if not is_tf2_or_later:
        # The 1.x loader restores into an existing session, so one is mandatory.
        if session is None:
            raise ValueError('session is required in tensorflow 1.x')
        tf.saved_model.loader.load(session, [SAVED_MODEL_TAG], export_dir=model_dir)
        info('done! session has restored.')
        return session
    return tf.saved_model.load(model_dir)