Last active
June 24, 2020 03:55
-
-
Save rwalk/384b9cce2e83c1502f607b2187176789 to your computer and use it in GitHub Desktop.
Co-occurrence-based recommendation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import json | |
import sys | |
import warnings | |
import numpy as np | |
from scipy.sparse import load_npz, coo_matrix | |
class CooccurenceRecommender:
    """Co-occurrence based recommendation engine.

    A query selects one or more catalogue items; every item is then scored by
    ``U.T @ (U @ q)``, where ``q`` is the indicator vector of the selected
    items. Items that never co-occur with the selection get no score and are
    not returned.
    """

    # Upper bound on how many catalogue items a single query may select.
    _MAX_MATCHES = 25

    def __init__(self, U, items):
        """
        U: sparse matrix, n users by m items (any scipy.sparse format)
        items: list(dict) of item metadata where each dict contains at least
            the keys title, index
        """
        self._U = U
        self._items = items

    @staticmethod
    def _field_matches(field, field_value, needle):
        """Return True if the lower-cased ``needle`` occurs in ``field_value``.

        ``field_value`` may be a string or a list of strings; a missing or
        empty value never matches.

        Raises:
            ValueError: if the field holds any other (non-empty) type.
        """
        if not field_value:
            return False
        if isinstance(field_value, str):
            return needle in field_value.lower()
        if isinstance(field_value, list) and all(isinstance(f, str) for f in field_value):
            return any(needle in f.lower() for f in field_value)
        raise ValueError(f"Field {field} is not queryable!")

    def _item_index_lookup(self, **kwargs):
        """Return the ``index`` of every item matching all query conditions.

        Each keyword is a metadata field name and each value a
        case-insensitive substring that the field must contain. An empty
        query matches nothing. At most ``_MAX_MATCHES`` indexes are returned;
        a warning is issued when the query matched more than that.

        Raises:
            ValueError: if a query value is not a string, or a queried field
                is neither a string nor a list of strings.
        """
        if not kwargs:
            return []

        # Validate and normalise the query once instead of once per item.
        needles = {}
        for field, value in kwargs.items():
            if not isinstance(value, str):
                raise ValueError(f"Query value for field {field} must be a string!")
            needles[field] = value.lower()

        indexes = []
        for item in self._items:
            # all() short-circuits, so later fields are not inspected once
            # an earlier condition has failed.
            if not all(
                self._field_matches(field, item.get(field), needle)
                for field, needle in needles.items()
            ):
                continue
            if len(indexes) >= self._MAX_MATCHES:
                warnings.warn(
                    f"More than {self._MAX_MATCHES} items matched this query. "
                    f"Only taking first {self._MAX_MATCHES}."
                )
                return indexes
            print(f"Matched item: {item}")
            indexes.append(item["index"])
        return indexes

    def _build_query_vector(self, indexes):
        """Build the sparse (m x 1) indicator column for the given item indexes.

        ``None`` entries are skipped. Index 0 is a valid item and is kept.
        A repeated index accumulates (COO construction sums duplicates).
        """
        rows = [idx for idx in indexes if idx is not None]
        cols = [0] * len(rows)
        ones = [1] * len(rows)
        q = coo_matrix(
            (ones, (rows, cols)),
            shape=(self._U.shape[-1], 1),
            dtype=np.float64,
        ).tocsr()
        return q

    def _score(self, q, number):
        """Score all items against query vector ``q`` and return the top ``number``.

        Returns a list of ``{"item": <metadata dict>, "score": <float>}``
        sorted by descending score.
        """
        y = self._U.transpose().dot(self._U.dot(q))
        # The storage format of the product depends on the format of U, so
        # convert to COO to read explicit row (item) indexes instead of
        # relying on the meaning of ``indices``.
        y = y.tocoo()
        # NOTE(review): assumes self._items is ordered so that position i
        # describes matrix column i -- confirm against the data pipeline.
        recs = [
            {"item": self._items[i], "score": float(score)}
            for i, score in zip(y.row, y.data)
        ]
        recs.sort(key=lambda rec: rec["score"], reverse=True)
        return recs[:number]

    def recommend(self, number=10, **kwargs):
        """Return up to ``number`` recommendations for the items matching ``kwargs``.

        ``kwargs`` are field=substring conditions, as accepted by
        ``_item_index_lookup``.
        """
        indexes = self._item_index_lookup(**kwargs)
        q = self._build_query_vector(indexes)
        return self._score(q, number)
if __name__ == "__main__":
    # Interactive command-line front end: load the user-item matrix and the
    # item catalogue, then answer JSON queries read from stdin until the user
    # interrupts (Ctrl-C) or closes the input stream (Ctrl-D / EOF).
    parser = argparse.ArgumentParser("Cooccurence Recommender")
    parser.add_argument("matrix_file", help="Sparse user item matrix in npz format")
    parser.add_argument("items_file", help="File of JSON array where each item contains at least the keys title, index")
    args = parser.parse_args()

    # NOTE: `U` and `items` are deliberately kept as module-level names;
    # other code in this file may read them as globals.
    U = load_npz(args.matrix_file)
    with open(args.items_file, encoding="utf-8") as f:
        items = json.load(f)
    recommender = CooccurenceRecommender(U, items)

    # Keys that would collide with recommend()'s own parameters when the
    # query is expanded with ** into keyword arguments.
    reserved_keys = {"self", "number"}

    try:
        while True:
            q_string = input("Enter a query as JSON (type 'example' for help):\n").strip()
            if not q_string:
                continue
            if q_string.lower() == "example":
                print("Example: {\"authors\": \"Paul Bowles\", \"title\": \"Sky\"}")
                continue

            try:
                query = json.loads(q_string)
            except json.JSONDecodeError:
                print("Query is not valid JSON!")
                continue

            # Valid JSON is not necessarily a usable query: it must be an
            # object mapping field names to substring patterns.
            if not isinstance(query, dict) or not all(isinstance(v, str) for v in query.values()):
                print("Query must be a JSON object whose values are strings!")
                continue
            if reserved_keys & query.keys():
                print("Query keys 'self' and 'number' are reserved!")
                continue

            try:
                hits = recommender.recommend(number=5, **query)
            except ValueError as err:
                # Raised for fields that cannot be queried; report it and
                # keep the session alive instead of crashing.
                print(err)
                continue
            for hit in hits:
                print(json.dumps(hit, indent=2))
    except (KeyboardInterrupt, EOFError):
        sys.exit(0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment