aboutsummaryrefslogtreecommitdiff
path: root/utils.py
diff options
context:
space:
mode:
authorBen Cohen <ben@kensho.com>2019-07-28 13:51:38 -0400
committerBen Cohen <ben@kensho.com>2019-07-28 13:51:38 -0400
commite92085f124cb5f16edfc4c19dac1e6823d4838b7 (patch)
tree62a90ab516447efc930ae130008efe572f96ccd8 /utils.py
parent52b6c48da443a6fba35dec68348499640944790d (diff)
returning
Diffstat (limited to 'utils.py')
-rw-r--r--utils.py48
1 files changed, 28 insertions, 20 deletions
diff --git a/utils.py b/utils.py
index 7257c8a..dd7afce 100644
--- a/utils.py
+++ b/utils.py
@@ -107,20 +107,24 @@ def get_closest_to_point(one_embed, style_limit=[]):
else:
possible_beers = small_embeddings
print(len(possible_beers))
+
+ to_return = []
for thing in sorted(possible_beers, key = lambda x: cosine(one_embed, x[1]), reverse=False)[:5]:
# print(b, a)
if thing[0] in beer_names:
bn = beer_names[thing[0]]
- print(bn[0], '---', bn[1][0])
- print('=' * 50)# In[16]:
+
+ to_return.append((thing[0], bn[0], bn[1][0]))
+ return to_return
-def translate_to_attr(embedding, to_attr):
- if to_attr.startswith('-'):
+
+def translate_to_attr(embedding, to_attr, amt):
+ if amt < 0:
back = True
- to_attr = to_attr[1:]
+ amt = abs(amt)
else:
back = False
@@ -138,14 +142,19 @@ def translate_to_attr(embedding, to_attr):
print('=' *10)
print('\n')
dist_between = np.linalg.norm(closest_center_vector-embedding)
- for x in [4,2, 1]:
- print('Moving {}%'.format(1/float(x) * 100))
- if back:
- new_point = embedding - (vector_between/x/5)
- else:
- new_point = embedding + (vector_between/x)
- get_closest_to_point(new_point)
- print('-' * 25)
+ amt_dict = {
+ 1: 4,
+ 2: 2,
+ 3: 1
+ }
+ x = amt_dict[amt]
+ print('Moving {}%'.format(1/float(x) * 100))
+ if back:
+ new_point = embedding - (vector_between/x)
+ else:
+ new_point = embedding + (vector_between/x)
+ return get_closest_to_point(new_point)
+
def translate_to_style(embedding, style):
@@ -157,12 +166,11 @@ def translate_to_style(embedding, style):
print('=' *10)
print('\n')
dist_between = np.linalg.norm(closest_center_vector-embedding)
- for x in [4,2, 1]:
- print('Moving {}%'.format(1/float(x) * 100))
- new_point = embedding + (vector_between/x)
- get_closest_to_point(new_point, style_limit=style)
- print('-' * 25)
+ TRANSLATION_CONSTANT = 4 # 4 is slight
+
+ new_point = embedding + (vector_between/TRANSLATION_CONSTANT)
+ return get_closest_to_point(new_point, style_limit=style)
# In[17]:
@@ -177,7 +185,7 @@ info = beer_names[beer_id]
print('Real ABV: {}'.format(info[2]))
one_embed = final_data[beer_id]['embed']
-translate_to_style(one_embed, 'Belgian Quadrupel (Quad)')
+print(translate_to_style(one_embed, 'Belgian Quadrupel (Quad)'))
# In[22]:
@@ -192,7 +200,7 @@ info = beer_names[beer_id]
print('Real ABV: {}'.format(info[2]))
one_embed = final_data[beer_id]['embed']
-translate_to_attr(one_embed, 'Hoppy')
+print(translate_to_attr(one_embed, 'Hoppy', 1))
# In[ ]: