Source code for deepmatch.utils

# -*- coding:utf-8 -*-
"""

Author:
    Weichen Shen, wcshenswc@163.com

"""

import json
import logging
import requests
from collections import namedtuple
from threading import Thread

try:
    from packaging.version import parse
except ImportError:
    from pip._vendor.packaging.version import parse

import tensorflow as tf

from tensorflow.keras import backend as K
from tensorflow.keras.layers import Lambda


class NegativeSampler(
    namedtuple('NegativeSampler', ['sampler', 'num_sampled', 'item_name', 'item_count', 'distortion'])):
    """Immutable record describing how negative samples are drawn.

    Args:
        sampler: sampler name, one of ``'inbatch'``, ``'uniform'``,
            ``'frequency'`` or ``'adaptive'``.
        num_sampled: number of negative samples per positive sample.
        item_name: pkey of the item features.
        item_count: global frequency of the items. Required (must not be
            ``None``) for the ``'inbatch'`` and ``'frequency'`` samplers.
        distortion: skew factor of the unigram probability distribution.

    Raises:
        ValueError: if ``sampler`` is not a supported name, or if
            ``item_count`` is ``None`` for a sampler that requires it.
    """
    # No per-instance __dict__: the tuple fields are the only state.
    __slots__ = ()

    def __new__(cls, sampler, num_sampled, item_name, item_count=None, distortion=1.0, ):
        supported_samplers = ('inbatch', 'uniform', 'frequency', 'adaptive')
        count_based_samplers = ('inbatch', 'frequency')

        if sampler not in supported_samplers:
            raise ValueError(' `%s` sampler is not supported ' % sampler)

        if item_count is None and sampler in count_based_samplers:
            raise ValueError(' `item_count` must not be `None` when using `inbatch` or `frequency` sampler')

        field_values = (sampler, num_sampled, item_name, item_count, distortion)
        return super(NegativeSampler, cls).__new__(cls, *field_values)
# def __hash__(self): # return self.sampler.__hash__()
def l2_normalize(x, axis=-1):
    """Apply ``tf.nn.l2_normalize`` to ``x`` through a Keras ``Lambda`` layer.

    Args:
        x: input handed to the ``Lambda`` layer.
        axis: axis forwarded to ``tf.nn.l2_normalize``; defaults to ``-1``.

    Returns:
        The output of the ``Lambda`` layer called on ``x``.
    """
    normalize_layer = Lambda(lambda tensor: tf.nn.l2_normalize(tensor, axis))
    return normalize_layer(x)
def inner_product(x, y, temperature=1.0):
    """Sum the element-wise product of ``x`` and ``y``, scaled by ``1 / temperature``.

    The computation is wrapped in a Keras ``Lambda`` layer that receives
    ``[x, y]`` as its input.

    Args:
        x: first operand.
        y: second operand.
        temperature: divisor applied to the summed product; defaults to ``1.0``.

    Returns:
        The output of the ``Lambda`` layer called on ``[x, y]``.
    """
    # NOTE(review): ``tf.reduce_sum`` is called without an axis, so the product
    # is summed over every element, not per row -- confirm this is intended.
    product_layer = Lambda(
        lambda pair: tf.reduce_sum(tf.multiply(pair[0], pair[1])) / temperature
    )
    return product_layer([x, y])
def recall_N(y_true, y_pred, N=50):
    """Compute recall@N: the share of ``y_true`` found in the top ``N`` of ``y_pred``.

    Args:
        y_true: sized collection of relevant items.
        y_pred: sequence of predicted items; only the first ``N`` entries
            (``y_pred[:N]``) are considered.
        N: cut-off applied to ``y_pred``; defaults to ``50``.

    Returns:
        ``len(set(y_pred[:N]) & set(y_true)) / len(y_true)`` as a float, or
        ``0.0`` when ``y_true`` is empty.
    """
    num_relevant = len(y_true)
    # Nothing to recall: avoid a ZeroDivisionError on an empty ground truth.
    if num_relevant == 0:
        return 0.0
    hits = set(y_pred[:N]) & set(y_true)
    return len(hits) * 1.0 / num_relevant
def sampledsoftmaxloss(y_true, y_pred):
    """Return the mean of ``y_pred`` via the Keras backend; ``y_true`` is unused.

    Args:
        y_true: ignored by this function.
        y_pred: values to average.
            NOTE(review): presumably the model output already holds the loss
            values -- verify against the caller.

    Returns:
        ``K.mean(y_pred)``.
    """
    mean_value = K.mean(y_pred)
    return mean_value
def get_item_embedding(item_embedding, item_input_layer):
    """Look up entries of ``item_embedding`` by index through a ``Lambda`` layer.

    The layer gathers from ``item_embedding`` using its input as indices
    (``tf.gather``) and then removes axis 1 from the result (``tf.squeeze``).

    Args:
        item_embedding: params argument of ``tf.gather``.
        item_input_layer: indices fed to the ``Lambda`` layer.

    Returns:
        The output of the ``Lambda`` layer called on ``item_input_layer``.
    """
    lookup_layer = Lambda(
        lambda item_index: tf.squeeze(tf.gather(item_embedding, item_index), axis=1)
    )
    return lookup_layer(item_input_layer)
def check_version(version):
    """Check PyPI in a background thread for a newer stable deepmatch release.

    The release list is fetched from the package's PyPI JSON endpoint.
    Pre-releases and post-releases are skipped. If the newest remaining
    release is greater than ``version``, a warning with upgrade instructions
    is logged. Any failure (network error, timeout, unparsable data) is
    reported with a printed hint to check manually; nothing is raised to the
    caller.

    Args:
        version: version string of the installed package.

    Returns:
        None. The check runs on a separate thread that is started before this
        function returns.
    """

    def check(version):
        try:
            url_pattern = 'https://pypi.python.org/pypi/deepmatch/json'
            # Without a timeout, a stalled connection would block this thread
            # indefinitely.
            req = requests.get(url_pattern, timeout=10)
            latest_version = parse('0')
            version = parse(version)
            if req.status_code == requests.codes.ok:
                j = json.loads(req.text.encode('utf-8'))
                releases = j.get('releases', [])
                for release in releases:
                    ver = parse(release)
                    # Only stable releases count as upgrade candidates.
                    if ver.is_prerelease or ver.is_postrelease:
                        continue
                    latest_version = max(latest_version, ver)
                if latest_version > version:
                    logging.warning(
                        '\nDeepMatch version {0} detected. Your version is {1}.\nUse `pip install -U deepmatch` to upgrade.Changelog: https://github.com/shenweichen/DeepMatch/releases/tag/v{0}'.format(
                            latest_version, version))
        except Exception:
            # Best-effort check: swallow ordinary errors, but let SystemExit
            # and KeyboardInterrupt through (a bare ``except`` would not).
            print("Please check the latest version manually on https://pypi.org/project/deepmatch/#history")
            return

    Thread(target=check, args=(version,)).start()