#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Assign a top-level HAL scientific domain to English texts.

Reads one JSON object per line from stdin. The text in each object's
``value`` field is embedded with a fastText sentence model. Its ``K``
nearest rows of the reference matrix are found with a KD-tree (Euclidean
distance, p=2), and the neighbours' labels vote with inverse-distance
weights. ``value`` is replaced by the human-readable name of the winning
domain, and the object is written back to stdout as one JSON line.

Created on Thu Aug 12 10:19:31 2021

@author: cuxac
"""
import json
import operator
import pickle
import sys
from collections import defaultdict

import numpy as np
from scipy.spatial import cKDTree as KDTree

MODEL_PATH = "./v1/en/modelhal0EN2.bin"
MATRIX_PATH = "./v1/en/HALen_matrixKDT.npy"
LABELS_PATH = "./v1/en/dico_nd.pkl"

# Number of nearest neighbours that vote for a domain.
K = 50

# Domain code (as extracted by domain_code) -> name written to the output.
verb_class = {
    "spi": "Sciences de l'ingénieur [physics]",
    "shs": "Sciences de l'Homme et Société",
    "sdv": "Sciences du Vivant [q-bio]",
    "sdu": "Planète et Univers [physics]",
    "sde": "Sciences de l'environnement",
    "scco": "Sciences cognitives",
    "phys": "Physique [physics]",
    "nlin": "Science non linéaire [physics]",
    "math": "Mathématiques [math]",
    "info": "Informatique [cs]",
    "chim": "Chimie",
    "stat": "Statistiques",
    "qfin": "Économie et finance quantitative [q-fin]",
}


def domain_code(label):
    """Return the domain code embedded in a reference label.

    The code is the text between the first ``_`` and the next ``.``, so a
    label shaped like ``"x_sdv.y"`` gives ``"sdv"``. The label format is
    inferred from this parsing only; confirm it against the pickled table.
    """
    return label.split('_')[1].split('.')[0]


def vote(labels, distances):
    """Return the label with the largest sum of inverse neighbour distances.

    Each neighbour adds ``1 / distance`` to its label's score, so several
    neighbours with the same label and the same distance all count. A zero
    distance adds ``inf``, so an exact match wins outright. Ties go to the
    label seen first, i.e. to the nearest neighbour when ``distances`` is
    sorted ascending (as ``cKDTree.query`` returns it).

    Softmax is deliberately not applied. It is monotonic, so it cannot change
    the winner, and it turns an ``inf`` score into NaN for every class.

    Raises:
        ValueError: if ``labels`` is empty.
    """
    with np.errstate(divide="ignore"):
        weights = 1.0 / np.asarray(distances, dtype=float)
    scores = defaultdict(float)
    for label, weight in zip(labels, weights):
        scores[label] += weight
    if not scores:
        raise ValueError("no neighbours to vote with")
    return max(scores.items(), key=operator.itemgetter(1))[0]


def classify(vector, tree, labels, k=K):
    """Return the winning reference label for ``vector``.

    Args:
        vector: query point with the same dimensionality as the tree's data.
        tree: ``cKDTree`` built over the reference matrix.
        labels: maps a tree row index to its label (indexed with ints).
        k: number of nearest neighbours that vote.
    """
    distances, indices = tree.query(vector, k=k, p=2)
    # With k=1 cKDTree returns scalars; normalise to 1-D arrays.
    distances = np.atleast_1d(distances)
    indices = np.atleast_1d(indices)
    # cKDTree reports missing neighbours (k larger than the tree) with an
    # infinite distance and an out-of-range index that has no label.
    found = np.isfinite(distances)
    return vote([labels[i] for i in indices[found]], distances[found])


def main(stdin=None, stdout=None):
    """Classify every JSON line of ``stdin`` and write the results to ``stdout``."""
    import fasttext  # third-party; imported here so the helpers load without it

    stdin = sys.stdin if stdin is None else stdin
    stdout = sys.stdout if stdout is None else stdout

    model = fasttext.load_model(MODEL_PATH)
    tree = KDTree(np.load(MATRIX_PATH))
    # Trusted artefact shipped with the service; never unpickle untrusted files.
    with open(LABELS_PATH, "rb") as handle:
        labels = pickle.load(handle)

    for line in stdin:
        if not line.strip():
            continue  # skip empty lines instead of crashing in json.loads
        data = json.loads(line)
        # fastText's get_sentence_vector rejects text containing '\n'.
        text = data['value'].strip().replace('\n', ' ')
        label = classify(model.get_sentence_vector(text), tree, labels)
        data['value'] = verb_class[domain_code(label)]
        stdout.write(json.dumps(data))
        stdout.write('\n')


if __name__ == "__main__":
    main()