1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
| import json import traceback import numpy as np from keras.models import load_model from keras_bert import get_custom_objects from keras_contrib.layers import CRF from keras_contrib.losses import crf_loss from keras_contrib.metrics import crf_accuracy from flask import Flask, request
from model_train import PreProcessInputData, id_label_dict
def bio_to_json(string, tags):
    """Decode a per-character BIO tag sequence into entity spans.

    Args:
        string: the text; ``tags[i]`` labels ``string[i]``.
        tags: BIO labels such as ``"B-PER"``, ``"I-PER"`` or ``"O"``. The type
            is everything after the two-character prefix. Only the first
            ``min(len(string), len(tags))`` positions are decoded. The caller's
            list is not modified.

    Returns:
        ``{"string": string, "entities": [...]}``. Each entity is
        ``{"word", "start", "end", "type"}``, and ``[start, end)`` is a
        half-open character range into ``string``.

    An ``I-`` tag that does not directly follow a ``B-``/``I-`` tag of the same
    type (including one at position 0) is treated as ``"O"``. Adjacent
    entities of the same type (e.g. ``B-PER B-PER``) are emitted separately.
    """
    item = {"string": string, "entities": []}
    # Work on a copy: orphan I- tags are rewritten to "O" so later positions
    # see the corrected sequence, without touching the caller's list.
    tags = list(tags)
    length = min(len(string), len(tags))
    entity_name = ""
    entity_start = 0
    entity_tag = ""
    for idx in range(length):
        char, tag = string[idx], tags[idx]
        # Look ahead only inside the decoded span, so an entity that reaches
        # the end of that span is still closed and emitted.
        tag_next = tags[idx + 1] if idx + 1 < length else ""
        if tag[:1] == "B":
            entity_tag = tag[2:]
            entity_name = char
            entity_start = idx
        elif tag[:1] == "I":
            # Valid only as a continuation of the immediately preceding B-/I-
            # tag of the same type; idx == 0 has no predecessor (no wraparound).
            prev = tags[idx - 1] if idx > 0 else "O"
            if prev[:1] not in ("B", "I") or prev[2:] != tag[2:]:
                tags[idx] = "O"
                continue
            entity_name += char
        else:
            continue
        # The entity ends unless the next tag is an I- tag of the same type;
        # a following B- tag (even of the same type) starts a new entity.
        if not (tag_next[:1] == "I" and tag_next[2:] == entity_tag):
            item["entities"].append({"word": entity_name, "start": entity_start, "end": idx + 1, "type": entity_tag})
            entity_name = ""
    return item
# Flask application object; the NER endpoint is registered on it via @app.route.
app = Flask(import_name=__name__)
@app.route("/model/ner", methods=["GET", "POST"])
def get_geo():
    """Run NER on the ``text`` field of the JSON request body.

    Always answers with HTTP 200 and a JSON envelope
    ``{"code", "message", "data"}``:

    * ``code`` 200: ``data`` is the ``bio_to_json`` result for the text
      (with ASCII spaces removed).
    * ``code`` 400: the body is not a JSON object with a string ``text`` field.
    * ``code`` 500: preprocessing or prediction failed. The full traceback is
      logged server-side only, so stack frames and paths are not sent to
      clients.
    """
    return_result = {"code": 200, "message": "success", "data": []}
    # silent=True: a missing/invalid JSON body yields None instead of raising,
    # so bad input is reported as a client error rather than a crash.
    payload = request.get_json(silent=True)
    text = payload.get("text") if isinstance(payload, dict) else None
    if not isinstance(text, str):
        return_result["code"] = 400
        return_result["message"] = 'request body must be a JSON object with a string "text" field'
    else:
        try:
            # NOTE(review): only ASCII spaces are stripped; presumably the
            # tokenizer in PreProcessInputData drops other whitespace too,
            # which would misalign characters and tags - confirm.
            text = text.replace(" ", "")
            word_labels, seq_types = PreProcessInputData([text])
            predicted = ner_model.predict([word_labels, seq_types])
            label_ids = np.argmax(predicted[0], axis=1)
            tags = [id_label_dict[label_id] for label_id in label_ids]
            # Drop the first and last positions; presumably the [CLS]/[SEP]
            # special tokens added by PreProcessInputData - TODO confirm.
            return_result["data"] = bio_to_json(text, tags[1:-1])
        except Exception as exc:
            # Internal failure, not a client error: keep the traceback in the
            # server log and return only a short description.
            app.logger.exception("NER prediction failed")
            return_result["code"] = 500
            return_result["message"] = "{}: {}".format(type(exc).__name__, exc)
    body = json.dumps(return_result, ensure_ascii=False, indent=2)
    return app.response_class(body, mimetype="application/json")
if __name__ == '__main__':
    # Add the keras_contrib CRF layer, loss and metric to keras_bert's custom
    # objects so load_model can deserialize the saved model.
    custom_objects = get_custom_objects()
    custom_objects.update(CRF=CRF, crf_loss=crf_loss, crf_accuracy=crf_accuracy)
    # NOTE(review): get_geo reads this module-level global, which only exists
    # when the file is executed as a script; serving `app` through a WSGI
    # server would leave it undefined - confirm the intended deployment.
    ner_model = load_model("example_ner.h5", custom_objects=custom_objects)
    app.run(host="0.0.0.0", port=25000)
|