"""Beer-embedding lookup helpers.

At import time this module loads style metadata, per-beer metadata and
per-beer embeddings from the ``data/`` directory, computes a mean embedding
("center") per style and per attribute, and then exposes functions that
search for beers near a given embedding using cosine distance.
"""
import re
import funcy
import csv
import numpy as np
import datetime
import json
from collections import defaultdict
import glob
from scipy.spatial.distance import cosine

# Thresholds applied to ``beer_names[beer_id][3]``.
# NOTE(review): index 3 is presumably a review/rating count -- confirm
# against the producer of data/beer_info.json.
_CENTER_MIN_COUNT = 7
_NEIGHBOUR_MIN_COUNT = 25

# Maps the user-facing "amount" (1..3) to the divisor applied to the vector
# between a beer and the target center: 1 -> 25%, 2 -> 50%, 3 -> 100%.
_AMOUNT_TO_DIVISOR = {1: 4, 2: 2, 3: 1}

# Divisor used by translate_to_style ("4 is slight").
_STYLE_TRANSLATION_CONSTANT = 4

# The data contains non-ASCII style names, so the encoding is pinned rather
# than left to the platform default.
with open('data/styles.json', encoding='utf-8') as f:
    styles = json.load(f)

with open('data/beer_info.json', 'r', encoding='utf-8') as f:
    beer_names = json.load(f)

# attribute -> styles flagged with it, and the inverse mapping.
attr_to_styles = defaultdict(list)
style_to_attrs = defaultdict(list)
with open('data/sparser_beer_data.csv', encoding='utf-8', newline='') as f:
    reader = csv.DictReader(f)
    for line in reader:
        style = line.pop('Style')
        for k, v in line.items():
            # A cell value of '1' marks the attribute as present for the style.
            if v == '1':
                attr_to_styles[k].append(style)
                style_to_attrs[style].append(k)

# In[7]:

with open('data/final_data_small.json', encoding='utf-8') as f:
    final_data = json.load(f)

# Group embeddings by the first style listed for each beer.
beer_by_style = defaultdict(list)
for beer, data in final_data.items():
    try:
        info = beer_names[beer]
        real_style = info[1][0]
        if info[3] > _CENTER_MIN_COUNT:
            beer_by_style[real_style].append(data['embed'])
    except (KeyError, IndexError, TypeError):
        # Best effort: beers with missing or malformed metadata are skipped.
        continue

# Style centers are keyed by the *stripped* style name.
style_centers = {}
for style, datas in beer_by_style.items():
    style_centers[style.strip()] = np.mean(datas, axis=0)

style_name_to_num = {y[0]: x for x, y in styles.items()}

# Attribute center = mean of the centers of every style carrying it.
attr_centers = {}
for attr, rel_styles in attr_to_styles.items():
    # Strip so the lookup agrees with how style_centers is keyed.
    centers = [style_centers[x.strip()] for x in rel_styles]
    attr_centers[attr] = np.mean(centers, axis=0)

embeddings = [(beer_id, entry['embed']) for beer_id, entry in final_data.items()]

# Search pool: (beer_id, embedding) pairs above the neighbour threshold, with
# small_styles holding the matching (unstripped) style name at the same index.
small_embeddings = []
small_styles = []
bad = 0
for e in embeddings:
    newkey = e[0]
    # NOTE(review): the source's indentation was lost; the ``else`` is assumed
    # to pair with this membership test (counting beers without metadata).
    if newkey in beer_names:
        info = beer_names[newkey]
        if info[3] > _NEIGHBOUR_MIN_COUNT:
            small_embeddings.append(e)
            small_styles.append(info[1][0])
    else:
        bad += 1


def _rank_by_cosine(point, candidates):
    """Return ``(beer_id, embedding)`` pairs sorted by cosine distance to *point*."""
    return sorted(candidates, key=lambda entry: cosine(point, entry[1]))


def _shift(embedding, target, divisor, away=False):
    """Move *embedding* toward (or away from) *target* by ``1/divisor`` of the gap."""
    embedding = np.asarray(embedding)
    step = (target - embedding) / divisor
    return embedding - step if away else embedding + step


def get_closest(beer_id):
    """Return the ids of up to ten beers nearest to *beer_id* by cosine distance.

    The query beer itself is excluded by id. The previous implementation
    dropped the first hit unconditionally, which discarded the true nearest
    neighbour whenever the query beer was not part of the search pool.

    Each returned id is printed, followed by a separator line.

    Raises:
        KeyError: if *beer_id* is not in ``final_data``.
    """
    one_embed = final_data[beer_id]['embed']
    candidates = [entry for entry in small_embeddings if entry[0] != beer_id]
    beer_ids = []
    for other_id, _ in _rank_by_cosine(one_embed, candidates)[:10]:
        if other_id in beer_names:
            beer_ids.append(other_id)
            print(other_id)
    print('=' * 50)
    return beer_ids


def get_closest_to_point(one_embed, style_limit=None):
    """Return the five pool beers nearest to the embedding *one_embed*.

    Args:
        one_embed: embedding vector to search around.
        style_limit: optional style name, or collection of style names, to
            restrict the search to. Falsy values (``None``, ``''``, ``[]``)
            mean no restriction. Names are compared ignoring surrounding
            whitespace, because ``style_centers`` keys are stripped while
            ``small_styles`` entries are not.

    Returns:
        A list of ``(beer_id, beer_names[beer_id][0], first_style)`` tuples,
        nearest first. The size of the candidate pool is printed.
    """
    if style_limit:
        if isinstance(style_limit, str):
            wanted = {style_limit.strip()}
        else:
            wanted = {name.strip() for name in style_limit}
        possible_beers = [
            entry
            for style, entry in zip(small_styles, small_embeddings)
            if style.strip() in wanted
        ]
    else:
        possible_beers = small_embeddings
    print(len(possible_beers))

    to_return = []
    for found_id, _ in _rank_by_cosine(one_embed, possible_beers)[:5]:
        if found_id in beer_names:
            bn = beer_names[found_id]
            to_return.append((found_id, bn[0], bn[1][0]))
    return to_return


def translate_to_attr(beer_id, to_attr, amt):
    """Shift a beer toward/away from an attribute and return nearby beers.

    The target is the center of the style that carries *to_attr* and is
    closest (cosine distance) to the beer's embedding.

    Args:
        beer_id: key into ``final_data``.
        to_attr: attribute name (a column of the attribute CSV).
        amt: 1, 2 or 3 moves 25%, 50% or 100% of the way toward the target;
            a negative value moves the same fraction away from it.

    Returns:
        The result of ``get_closest_to_point`` at the shifted point.

    Raises:
        KeyError: if *beer_id* is unknown or ``abs(amt)`` is not 1, 2 or 3.
        IndexError: if no style center exists for *to_attr*.
    """
    embedding = final_data[beer_id]['embed']
    back = amt < 0
    amt = abs(amt)

    # .get avoids inserting an empty entry into the shared defaultdict when
    # the attribute is unknown; stripping matches style_centers' keys.
    relevant_styles = {name.strip() for name in attr_to_styles.get(to_attr, [])}
    small_rel_centers = {
        name: center
        for name, center in style_centers.items()
        if name in relevant_styles
    }
    if not small_rel_centers:
        # Same exception type the original raised via ``sorted(...)[0]``.
        raise IndexError(
            'no style centers available for attribute {!r}'.format(to_attr))

    closest_center = min(
        small_rel_centers,
        key=lambda name: cosine(small_rel_centers[name], embedding),
    )
    closest_center_vector = style_centers[closest_center]

    print('Moving Towards/From: {}'.format(closest_center))
    print('=' * 10)
    print('\n')

    if amt not in _AMOUNT_TO_DIVISOR:
        raise KeyError(
            'amt must be 1, 2 or 3 (optionally negated); got {!r}'.format(amt))
    x = _AMOUNT_TO_DIVISOR[amt]
    print('Moving {}%'.format(1 / float(x) * 100))

    new_point = _shift(embedding, closest_center_vector, x, away=back)
    return get_closest_to_point(new_point)


def translate_to_style(beer_id, style):
    """Shift a beer a quarter of the way toward *style* and return nearby beers.

    The search is restricted to beers of *style*. Surrounding whitespace in
    *style* is ignored for the center lookup.

    Raises:
        KeyError: if *beer_id* is unknown or *style* has no computed center.
    """
    embedding = final_data[beer_id]['embed']
    closest_center_vector = style_centers[style.strip()]

    print('Moving Towards/From: {}'.format(style))
    print('=' * 10)
    print('\n')

    new_point = _shift(embedding, closest_center_vector,
                       _STYLE_TRANSLATION_CONSTANT)
    return get_closest_to_point(new_point, style_limit=style)


def get_drinks_like(beer_id):
    """Return the ``'alc'`` entry stored for *beer_id* in ``final_data``."""
    return final_data[beer_id]['alc']


# Index -> style name, used to label the per-beer ``'style'`` score vector.
# Trailing spaces in some names are part of the data and are kept as-is.
num_to_style = {
    0: 'European Export / Dortmunder',
    1: 'German Bock',
    2: 'Low Alcohol Beer',
    3: 'American Black Ale',
    4: 'German Helles',
    5: 'New England IPA',
    6: 'American Amber / Red Lager',
    7: 'Irish Dry Stout',
    8: 'Leipzig Gose',
    9: 'German Maibock',
    10: 'Scottish Gruit / Ancient Herbed Ale',
    11: 'Belgian Strong Pale Ale',
    12: 'Robust Porter ',
    13: 'English Dark Mild Ale',
    14: 'Belgian Lambic',
    15: 'Belgian IPA',
    16: 'American Pale Ale (APA)',
    17: 'American Imperial Porter',
    18: 'American IPA',
    19: 'Belgian Gueuze',
    20: 'American Wheatwine Ale',
    21: 'California Common / Steam Beer',
    22: 'Smoke Porter',
    23: 'English Pale Mild Ale',
    24: 'Rye Beer',
    25: 'Russian Kvass',
    26: 'German Altbier',
    27: 'American Malt Liquor',
    28: 'Foreign / Export Stout',
    29: 'Japanese Rice Lager',
    30: 'German Pilsner',
    31: 'German Weizenbock',
    32: 'Belgian Witbier',
    33: 'English Old Ale',
    34: 'American Imperial Red Ale',
    35: 'Belgian Quadrupel (Quad)',
    36: 'American Stout',
    37: 'Belgian Faro',
    38: 'Pumpkin Beer',
    39: 'American Porter',
    40: 'Vienna Lager',
    41: 'Belgian Dark Ale',
    42: 'American Brut IPA',
    43: 'British Barleywine',
    44: 'German Kölsch',
    45: 'American Barleywine',
    46: 'German Kellerbier / Zwickelbier',
    47: 'Scotch Ale / Wee Heavy',
    48: 'European Strong Lager',
    49: 'German Kristalweizen',
    50: 'Baltic Porter',
    51: 'Chile Beer',
    52: 'American Cream Ale',
    53: '[ India Pale Ales ]',
    54: 'American Imperial Pilsner',
    55: 'American Imperial IPA',
    56: 'English Porter',
    57: 'English Sweet / Milk Stout',
    58: 'American Lager',
    59: 'American Imperial Stout',
    60: 'Belgian Blonde Ale ',
    61: 'English India Pale Ale (IPA)',
    62: 'German Eisbock',
    63: 'Belgian Pale Ale',
    64: 'American Light Lager',
    65: 'Russian Imperial Stout',
    66: 'German Hefeweizen',
    67: 'German Märzen / Oktoberfest',
    68: 'Flanders Red Ale',
    69: 'English Stout',
    70: 'Belgian Dubbel',
    71: 'American Blonde Ale',
    72: 'American Brown Ale',
    73: 'Finnish Sahti',
    74: 'English Oatmeal Stout',
    75: 'Fruit and Field Beer',
    76: 'Belgian Tripel',
    77: 'Belgian Strong Dark Ale',
    78: 'American Dark Wheat Ale',
    79: 'Smoke Beer',
    80: 'English Extra Special / Strong Bitter (ESB)',
    81: 'European Pale Lager',
    82: 'American Amber / Red Ale',
    83: 'Flanders Oud Bruin',
    84: 'American Strong Ale',
    85: 'English Brown Ale',
    86: 'European Dark Lager',
    87: 'French Bière de Garde',
    88: 'American Pale Wheat Ale',
    89: 'Munich Dunkel Lager',
    90: 'German Doppelbock',
    91: 'German Rauchbier',
    92: 'German Roggenbier',
    93: 'Scottish Ale',
    94: 'German Dunkelweizen',
    95: 'English Bitter',
    96: 'English Strong Ale',
    97: 'Winter Warmer',
    98: 'Herb and Spice Beer',
    99: 'American Adjunct Lager',
    100: 'Belgian Fruit Lambic',
    101: 'Berliner Weisse',
    102: 'Irish Red Ale',
    103: 'Bière de Champagne / Bière Brut',
    104: 'English Pale Ale',
    105: 'American Brett',
    106: 'Belgian Saison',
    107: 'Japanese Happoshu',
    108: 'Bohemian Pilsener',
    109: 'German Schwarzbier',
    110: 'Braggot',
    111: 'American Wild Ale',
}


def normalize(ret):
    """Scale the values of *ret* in place so they sum to 1, and return it.

    If the values sum to zero (including the empty mapping) the mapping is
    returned unchanged instead of raising ``ZeroDivisionError``.
    """
    total = sum(ret.values())
    if total == 0:
        return ret
    for key in ret:
        ret[key] = ret[key] / total
    return ret


def get_style_preds(beer_id):
    """Return the five highest-scoring styles for *beer_id*, normalized.

    The five entries are chosen by position, so tied scores can no longer
    produce more than five results. The selected raw scores are printed in
    ascending order, and the result is ordered by style index.

    Returns:
        A list of ``{'name': style_name, 'score': share}`` dicts whose scores
        sum to 1 (unless all selected scores are zero).

    Raises:
        KeyError: if *beer_id* is unknown, or a selected index has no entry
            in ``num_to_style``.
    """
    local_styles = final_data[beer_id]['style']
    top_indices = sorted(
        range(len(local_styles)), key=local_styles.__getitem__)[-5:]
    print([local_styles[idx] for idx in top_indices])

    ret = {num_to_style[idx]: local_styles[idx] for idx in sorted(top_indices)}
    ret = normalize(ret)
    return [{'name': name, 'score': score} for name, score in ret.items()]