"""Explore beer embeddings by nearest-neighbour search.

Loads per-beer embedding vectors and beer metadata, averages the embeddings
of each style into a "style center", and provides helpers that print the
beers closest (by cosine distance) to a beer, to an arbitrary point, or to
points obtained by moving an embedding towards a style / attribute center.

NOTE(review): this file was recovered from a notebook export whose line
breaks and indentation had been lost.  Block nesting was reconstructed from
the statement order; the spots where more than one nesting was plausible are
marked with NOTE(review) comments below -- please confirm them.
"""
import csv
import datetime
import glob
import json
import re
from collections import defaultdict

import funcy
import numpy as np
from scipy.spatial.distance import cosine


def _info_key(beer_id):
    """Return the ``beer_names`` key matching a ``final_data`` beer id.

    The two dash-separated parts of an id appear in opposite order in the
    two data sets (``'4-59'`` in ``final_data`` is ``'59-4'`` in
    ``beer_names``).

    Raises:
        ValueError: if ``beer_id`` does not consist of exactly two
            dash-separated parts.
    """
    first, second = beer_id.split('-')
    return second + '-' + first


def _nearest(one_embed, candidates, count):
    """Return the ``count`` ``(beer_id, embedding)`` pairs closest to ``one_embed``.

    Closeness is cosine distance, smallest first.
    """
    return sorted(candidates, key=lambda pair: cosine(one_embed, pair[1]))[:count]


with open('data/styles.json') as f:
    styles = json.load(f)

with open('data/beer_info.json', 'r') as f:
    beer_names = json.load(f)

# Each CSV row is one style; every other column is an attribute flag and the
# value '1' marks the attribute as applying to that style.
attr_to_styles = defaultdict(list)
style_to_attrs = defaultdict(list)
with open('data/sparser_beer_data.csv') as f:
    reader = csv.DictReader(f)
    for line in reader:
        style = line.pop('Style')
        for k, v in line.items():
            if v == '1':
                attr_to_styles[k].append(style)
                style_to_attrs[style].append(k)


# In[7]:

with open('data/final_data_small.json') as f:
    final_data = json.load(f)

# Group embeddings by the first listed style of each beer, keeping only
# beers whose info[3] value exceeds 7.  Beers whose id cannot be split or
# looked up, or whose record has an unexpected shape, are counted in
# ``failed`` and skipped.
failed = 0
beer_by_style = defaultdict(list)
for beer, data in final_data.items():
    try:
        beer_info = beer_names[_info_key(beer)]
        real_style = beer_info[1][0]
        if beer_info[3] > 7:
            beer_by_style[real_style].append(data['embed'])
    except (ValueError, KeyError, IndexError, TypeError):
        # Only the expected "bad record" errors are tolerated; anything else
        # (including KeyboardInterrupt) now propagates.
        failed += 1

# Mean embedding per style, keyed by the whitespace-stripped style name.
style_centers = {}
for style, datas in beer_by_style.items():
    style_centers[style.strip()] = np.mean(datas, axis=0)

style_name_to_num = {y[0]: x for x, y in styles.items()}

# Mean of the style centers of every style carrying an attribute.
# Raises KeyError if a style from the CSV has no entry in style_centers.
attr_centers = {}
for attr, rel_styles in attr_to_styles.items():
    centers = [style_centers[x] for x in rel_styles]
    attr_mean = np.mean(centers, axis=0)
    attr_centers[attr] = attr_mean

embeddings = [(key, value['embed']) for key, value in final_data.items()]

# Search pool: beers present in beer_names whose info[3] value exceeds 25.
# small_styles[i] is the first listed style of small_embeddings[i].
small_embeddings = []
small_styles = []
bad = 0
for e in embeddings:
    newkey = _info_key(e[0])
    if newkey in beer_names:
        info = beer_names[newkey]
        if info[3] > 25:
            small_embeddings.append(e)
            small_styles.append(info[1][0])
    else:
        # NOTE(review): ``bad`` is assumed to count ids missing from
        # beer_names (else paired with the outer ``if``) -- confirm.
        bad += 1


def get_closest(beer_id):
    """Print the metadata of the 25 pool beers closest to ``beer_id``.

    Args:
        beer_id: key into ``final_data``.

    Raises:
        KeyError: if ``beer_id`` is not in ``final_data``.
    """
    one_embed = final_data[beer_id]['embed']
    for thing in _nearest(one_embed, small_embeddings, 25):
        key = _info_key(thing[0])
        if key in beer_names:
            print(beer_names[key])
    # NOTE(review): separator assumed to follow the whole list -- confirm.
    print('=' * 50)


def get_closest_to_point(one_embed, style_limit=None):
    """Print the 5 pool beers closest to an arbitrary embedding.

    Prints the number of candidate beers first, then one
    ``name --- style`` line per match, then a separator line.

    Args:
        one_embed: embedding vector to search around.
        style_limit: optional style name; when given (and truthy) only pool
            beers whose first listed style equals it are considered.  The
            default is ``None`` rather than a shared mutable ``[]``; any
            falsy value still means "no restriction".
    """
    if style_limit:
        possible_beers = [
            e for style, e in zip(small_styles, small_embeddings)
            if style == style_limit
        ]
    else:
        possible_beers = small_embeddings
    print(len(possible_beers))

    for thing in _nearest(one_embed, possible_beers, 5):
        key = _info_key(thing[0])
        if key in beer_names:
            bn = beer_names[key]
            print(bn[0], '---', bn[1][0])
    # NOTE(review): separator assumed to follow the whole list -- confirm.
    print('=' * 50)


# In[16]:

def translate_to_attr(embedding, to_attr):
    """Move an embedding towards (or away from) an attribute and print neighbours.

    The target is the center of the style that carries ``to_attr`` and is
    closest (cosine distance) to ``embedding``.  Neighbours are printed for
    steps of 1/4, 1/2 and the full distance to that center.

    Args:
        embedding: starting embedding vector.
        to_attr: attribute name (a column of the styles CSV).  A leading
            ``'-'`` means "move away from" the attribute instead.

    Raises:
        ValueError: if no style center is available for ``to_attr``.
    """
    back = to_attr.startswith('-')
    if back:
        to_attr = to_attr[1:]

    embedding = np.asarray(embedding)
    # .get() avoids inserting unknown attributes into the defaultdict.
    relevant_styles = set(attr_to_styles.get(to_attr, []))
    small_rel_centers = {
        name: center for name, center in style_centers.items()
        if name in relevant_styles
    }
    if not small_rel_centers:
        raise ValueError(
            'no style center available for attribute {!r}'.format(to_attr))

    closest_center = min(
        small_rel_centers,
        key=lambda name: cosine(small_rel_centers[name], embedding))
    closest_center_vector = style_centers[closest_center]
    vector_between = closest_center_vector - embedding

    print('Moving Towards/From: {}'.format(closest_center))
    print('=' *10)
    print('\n')
    for x in [4, 2, 1]:
        print('Moving {}%'.format(1/float(x) * 100))
        if back:
            # Steps away from the center are a fifth the size of steps
            # towards it.
            new_point = embedding - (vector_between/x/5)
        else:
            new_point = embedding + (vector_between/x)
        get_closest_to_point(new_point)
        # NOTE(review): assumed to be printed once per step -- confirm.
        print('-' * 25)


def translate_to_style(embedding, style):
    """Move an embedding towards a style center and print neighbours.

    Neighbours are restricted to beers of ``style`` and printed for steps of
    1/4, 1/2 and the full distance to the style's center.

    Args:
        embedding: starting embedding vector.
        style: style name; must be a key of ``style_centers``.

    Raises:
        KeyError: if ``style`` has no entry in ``style_centers``.
    """
    embedding = np.asarray(embedding)
    closest_center_vector = style_centers[style]
    vector_between = closest_center_vector - embedding

    print('Moving Towards/From: {}'.format(style))
    print('=' *10)
    print('\n')
    for x in [4, 2, 1]:
        print('Moving {}%'.format(1/float(x) * 100))
        new_point = embedding + (vector_between/x)
        get_closest_to_point(new_point, style_limit=style)
        # NOTE(review): assumed to be printed once per step -- confirm.
        print('-' * 25)


# In[17]:

# beer_id = '388-1703' # cantillio gueueze
# beer_id = '140-276' # SN pale ale
beer_id = '4-59' # allagash white
# beer_id = '64-33832' # palo salto marron

info = beer_names[_info_key(beer_id)]
print('Real ABV: {}'.format(info[2]))
one_embed = final_data[beer_id]['embed']
translate_to_style(one_embed, 'Belgian Quadrupel (Quad)')


# In[22]:

# beer_id = '388-1703' # cantillio gueueze
# beer_id = '140-276' # SN pale ale
beer_id = '4-59' # allagash white
# beer_id = '64-33832' # palo salto marron

info = beer_names[_info_key(beer_id)]
print('Real ABV: {}'.format(info[2]))
one_embed = final_data[beer_id]['embed']
translate_to_attr(one_embed, 'Hoppy')


# In[ ]:


# In[ ]: