aboutsummaryrefslogtreecommitdiff
path: root/utils.py
diff options
context:
space:
mode:
authorBen Cohen <ben@kensho.com>2019-07-28 13:29:20 -0400
committerBen Cohen <ben@kensho.com>2019-07-28 13:29:20 -0400
commit23d2f73b5f721fd0987ecd127c7cf37fb249a705 (patch)
treeb8f9886603c6b9069feb13bf62eb63451f167109 /utils.py
parent9b2917bfc6d370ea5ff7bf156bb0f63268f8957e (diff)
add utils and data
Diffstat (limited to 'utils.py')
-rw-r--r--utils.py218
1 files changed, 218 insertions, 0 deletions
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..73decb8
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,218 @@
+import re
+import funcy
+import csv
+import numpy as np
+import datetime
+import json
+from collections import defaultdict
+import glob
+from scipy.spatial.distance import cosine
+
+
+with open('data/styles.json') as f:
+ styles = json.load(f)
+
+
+
+with open('data/beer_info.json', 'r') as f:
+ beer_names = json.load(f)
+
+
+attr_to_styles = defaultdict(list)
+style_to_attrs = defaultdict(list)
+with open('data/sparser_beer_data.csv') as f:
+ reader = csv.DictReader(f)
+ for line in reader:
+ style = line.pop('Style')
+ for k, v in line.items():
+ if v == '1':
+ attr_to_styles[k].append(style)
+ style_to_attrs[style].append(k)
+
+
+# In[7]:
+
+with open('data/final_data_small.json') as f:
+ final_data = json.load(f)
+
+
+
+
+failed = 0
+beer_by_style = defaultdict(list)
+for beer, data in final_data.items():
+ try:
+ a, b = beer.split('-')
+ real_style = beer_names[b + '-' + a][1][0]
+ if beer_names[b + '-' + a][3] > 7:
+ beer_by_style[real_style].append(data['embed'])
+
+ except:
+ failed += 1
+
+
+style_centers = {}
+for style, datas in beer_by_style.items():
+ style_centers[style.strip()] = np.mean(datas, axis=0)
+
+
+
+style_name_to_num = {y[0]: x for x,y in styles.items()}
+
+
+
+
+attr_centers = {}
+for attr, rel_styles in attr_to_styles.items():
+ centers = [style_centers[x] for x in rel_styles]
+ attr_mean = np.mean(centers, axis=0)
+ attr_centers[attr] = attr_mean
+
+
+
+embeddings = [(x[0], x[1]['embed']) for x in final_data.items()]
+
+
+
+small_embeddings = []
+small_styles = []
+bad = 0
+for e in embeddings:
+ b, a = e[0].split('-')
+ newkey = a + '-' + b
+ if newkey in beer_names:
+ info= beer_names[newkey]
+
+ if info[3] > 25:
+ small_embeddings.append(e)
+ small_styles.append(info[1][0])
+ else:
+ bad += 1
+
+
+
+def get_closest(beer_id):
+ one_embed = final_data[beer_id]['embed']
+
+
+ for thing in sorted(small_embeddings, key = lambda x: cosine(one_embed, x[1]), reverse=False)[:25]:
+
+ b, a = thing[0].split('-')
+# print(b, a)
+ if a + '-' + b in beer_names:
+ print(beer_names[a + '-' + b])
+ print('=' * 50)
+
+def get_closest_to_point(one_embed, style_limit=[]):
+# one_embed = final_data[beer_id]['embed']
+
+ if style_limit:
+ possible_beers = []
+ for style, e in zip(small_styles, small_embeddings):
+ if style == style_limit:
+ possible_beers.append(e)
+ else:
+ possible_beers = small_embeddings
+ print(len(possible_beers))
+ for thing in sorted(possible_beers, key = lambda x: cosine(one_embed, x[1]), reverse=False)[:5]:
+
+ b, a = thing[0].split('-')
+# print(b, a)
+ if a + '-' + b in beer_names:
+ bn = beer_names[a + '-' + b]
+ print(bn[0], '---', bn[1][0])
+ print('=' * 50)
+
+
+# In[16]:
+
+
+def translate_to_attr(embedding, to_attr):
+ if to_attr.startswith('-'):
+ back = True
+ to_attr = to_attr[1:]
+
+ else:
+ back = False
+
+ relevant_styles = attr_to_styles[to_attr]
+ small_rel_centers = {x: y for x, y in style_centers.items() if x in relevant_styles}
+
+ sorted_centers = sorted(small_rel_centers, key=lambda x: cosine(small_rel_centers[x], embedding))
+ closest_center = sorted_centers[0]
+ closest_center_vector = style_centers[closest_center]
+
+ vector_between = closest_center_vector - embedding
+
+ print('Moving Towards/From: {}'.format(closest_center))
+ print('=' *10)
+ print('\n')
+ dist_between = np.linalg.norm(closest_center_vector-embedding)
+ for x in [4,2, 1]:
+ print('Moving {}%'.format(1/float(x) * 100))
+ if back:
+ new_point = embedding - (vector_between/x/5)
+ else:
+ new_point = embedding + (vector_between/x)
+ get_closest_to_point(new_point)
+ print('-' * 25)
+
+def translate_to_style(embedding, style):
+
+ closest_center_vector = style_centers[style]
+
+ vector_between = closest_center_vector - embedding
+# print(vector_between)
+ print('Moving Towards/From: {}'.format(style))
+ print('=' *10)
+ print('\n')
+ dist_between = np.linalg.norm(closest_center_vector-embedding)
+ for x in [4,2, 1]:
+ print('Moving {}%'.format(1/float(x) * 100))
+
+ new_point = embedding + (vector_between/x)
+ get_closest_to_point(new_point, style_limit=style)
+ print('-' * 25)
+
+
+# In[17]:
+
+
+# beer_id = '388-1703' # cantillio gueueze
+# beer_id = '140-276' # SN pale ale
+beer_id = '4-59' # allagash white
+# beer_id = '64-33832' # palo salto marron
+
+info = beer_names[beer_id.split('-')[1] + '-' + beer_id.split('-')[0]]
+print('Real ABV: {}'.format(info[2]))
+
+one_embed = final_data[beer_id]['embed']
+translate_to_style(one_embed, 'Belgian Quadrupel (Quad)')
+
+
+# In[22]:
+
+
+# beer_id = '388-1703' # cantillio gueueze
+# beer_id = '140-276' # SN pale ale
+beer_id = '4-59' # allagash white
+# beer_id = '64-33832' # palo salto marron
+
+info = beer_names[beer_id.split('-')[1] + '-' + beer_id.split('-')[0]]
+print('Real ABV: {}'.format(info[2]))
+
+one_embed = final_data[beer_id]['embed']
+translate_to_attr(one_embed, 'Hoppy')
+
+
+# In[ ]:
+
+
+
+
+
+# In[ ]:
+
+
+
+