diff options
Diffstat (limited to 'utils.py')
-rw-r--r-- | utils.py | 48 |
1 files changed, 28 insertions, 20 deletions
@@ -107,20 +107,24 @@ def get_closest_to_point(one_embed, style_limit=[]): else: possible_beers = small_embeddings print(len(possible_beers)) + + to_return = [] for thing in sorted(possible_beers, key = lambda x: cosine(one_embed, x[1]), reverse=False)[:5]: # print(b, a) if thing[0] in beer_names: bn = beer_names[thing[0]] - print(bn[0], '---', bn[1][0]) - print('=' * 50)# In[16]: + + to_return.append((thing[0], bn[0], bn[1][0])) + return to_return -def translate_to_attr(embedding, to_attr): - if to_attr.startswith('-'): + +def translate_to_attr(embedding, to_attr, amt): + if amt < 0: back = True - to_attr = to_attr[1:] + amt = abs(amt) else: back = False @@ -138,14 +142,19 @@ def translate_to_attr(embedding, to_attr): print('=' *10) print('\n') dist_between = np.linalg.norm(closest_center_vector-embedding) - for x in [4,2, 1]: - print('Moving {}%'.format(1/float(x) * 100)) - if back: - new_point = embedding - (vector_between/x/5) - else: - new_point = embedding + (vector_between/x) - get_closest_to_point(new_point) - print('-' * 25) + amt_dict = { + 1: 4, + 2: 2, + 3: 1 + } + x = amt_dict[amt] + print('Moving {}%'.format(1/float(x) * 100)) + if back: + new_point = embedding - (vector_between/x) + else: + new_point = embedding + (vector_between/x) + return get_closest_to_point(new_point) + def translate_to_style(embedding, style): @@ -157,12 +166,11 @@ def translate_to_style(embedding, style): print('=' *10) print('\n') dist_between = np.linalg.norm(closest_center_vector-embedding) - for x in [4,2, 1]: - print('Moving {}%'.format(1/float(x) * 100)) - new_point = embedding + (vector_between/x) - get_closest_to_point(new_point, style_limit=style) - print('-' * 25) + TRANSLATION_CONSTANT = 4 # 4 is slight + + new_point = embedding + (vector_between/TRANSLATION_CONSTANT) + return get_closest_to_point(new_point, style_limit=style) # In[17]: @@ -177,7 +185,7 @@ info = beer_names[beer_id] print('Real ABV: {}'.format(info[2])) one_embed = final_data[beer_id]['embed'] -translate_to_style(one_embed, 'Belgian Quadrupel (Quad)') +print(translate_to_style(one_embed, 'Belgian Quadrupel (Quad)')) # In[22]: @@ -192,7 +200,7 @@ info = beer_names[beer_id] print('Real ABV: {}'.format(info[2])) one_embed = final_data[beer_id]['embed'] -translate_to_attr(one_embed, 'Hoppy') +print(translate_to_attr(one_embed, 'Hoppy', 1)) # In[ ]: |