NOTE(review): this patch was recovered from a copy in which every line break
and all indentation had been collapsed into single spaces. Line breaks were
restored so that each hunk matches the old/new line counts in its header.
Indentation, trailing whitespace on removed lines, and the exact position of
blank context lines in the first hunk are best-effort reconstructions, so the
patch may not apply byte-for-byte; regenerate it from the repository if exact
bytes matter.

diff --git a/astro-ner/v1/find-astro.py b/astro-ner/v1/find-astro.py
index f63c249..9238281 100755
--- a/astro-ner/v1/find-astro.py
+++ b/astro-ner/v1/find-astro.py
@@ -1,22 +1,26 @@
+#!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
 import sys
 import json
 from flair.models import SequenceTagger
 from flair.data import Sentence
-from unidecode import unidecode 
+from unidecode import unidecode
 import logging
 
-logging.getLogger('flair').handlers[0].stream = sys.stderr
+logging.getLogger("flair").handlers[0].stream = sys.stderr
+
 
 def data_normalization(sentence):
     cpy_sentence = sentence.lower()
     return cpy_sentence
-tagger = SequenceTagger.load("model.pt")
+
+
+tagger = SequenceTagger.load("v1/model.pt")
 
 for line in sys.stdin:
     data = json.loads(line)
-    text=data['value']
+    text = data["value"]
     PL = []
     TNQ = []
     SNAT = []
@@ -35,17 +39,51 @@
     SR = []
     sent = data_normalization(text)
     sentS = sent.split(".")
-    sentences = [Sentence(sentS[i]+".") for i in range(len(sentS))]
+    sentences = [Sentence(sentS[i] + ".") for i in range(len(sentS))]
     tagger.predict(sentences)
-    label_lists = {"PL": PL,"TNQ": TNQ,"SNAT": SNAT,"OA": OA,"SSO": SSO,"EB": EB,"ET": ET,"NRA": NRA,"CST": CST,"GAL": GAL,"AST": AST,"ST": ST,"AS": AS,"SN": SN,"XPL": XPL,"SR": SR}
+    label_lists = {
+        "PL": PL,
+        "TNQ": TNQ,
+        "SNAT": SNAT,
+        "OA": OA,
+        "SSO": SSO,
+        "EB": EB,
+        "ET": ET,
+        "NRA": NRA,
+        "CST": CST,
+        "GAL": GAL,
+        "AST": AST,
+        "ST": ST,
+        "AS": AS,
+        "SN": SN,
+        "XPL": XPL,
+        "SR": SR,
+    }
     for sentence in sentences:
-        for entity in sentence.get_spans('ner'):
+        for entity in sentence.get_spans("ner"):
             label_value = entity.labels[0].value
             if entity.text not in label_lists.get(label_value, []):
-                label_lists[label_value].append(entity.text) 
-    
-    returnDic = {unidecode('Planète'):PL,unidecode('Trou noirs, quasars et apparentés'):TNQ,'Satellite naturel':SNAT,'Objets artificiels':OA,unidecode('Système solaire') :SSO,unidecode('Étoiles binaires (et pulsars)'):EB,unidecode('Étoiles'):ET,unidecode('Nébuleuse et région apparentés'):NRA,'Constellations':CST,'Galaxies et amas de galaxie':GAL,unidecode('Astèroïdes'):AST,unidecode('Satue hypotétique'):ST,'amas stellaires':AS,'supernovas':SN,unidecode('exoplanètes'):XPL,'sursaut radio, source radio, autres sursauts':SR}
+                label_lists[label_value].append(entity.text)
+
+    returnDic = {
+        unidecode("Planète"): PL,
+        unidecode("Trou noirs, quasars et apparentés"): TNQ,
+        "Satellite naturel": SNAT,
+        "Objets artificiels": OA,
+        unidecode("Système solaire"): SSO,
+        unidecode("Étoiles binaires (et pulsars)"): EB,
+        unidecode("Étoiles"): ET,
+        unidecode("Nébuleuse et région apparentés"): NRA,
+        "Constellations": CST,
+        "Galaxies et amas de galaxie": GAL,
+        unidecode("Astèroïdes"): AST,
+        unidecode("Satue hypotétique"): ST,
+        "amas stellaires": AS,
+        "supernovas": SN,
+        unidecode("exoplanètes"): XPL,
+        "sursaut radio, source radio, autres sursauts": SR,
+    }
     # ajouter unidecode
-    data['value'] = {id:value for id, value in returnDic.items() if value != []}
+    data["value"] = {id: value for id, value in returnDic.items() if value != []}
     sys.stdout.write(json.dumps(data))
-    sys.stdout.write('\n')
\ No newline at end of file
+    sys.stdout.write("\n")