Source code for flask_pluginkit.utils

# -*- coding: utf-8 -*-
"""
flask_pluginkit.utils
~~~~~~~~~~~~~~~~~~~~~

Some tool classes and functions.

:copyright: (c) 2019 by staugur.
:license: BSD 3-Clause, see LICENSE for more details.
"""

import sys
import json
import shelve
import importlib

from re import compile
from functools import cmp_to_key
from os.path import join, abspath, isdir
from tempfile import gettempdir
from collections import deque
from time import time
from subprocess import call, check_output
from typing import List, Any, Optional, Dict

from flask import Response, jsonify
from markupsafe import Markup
from semver.version import Version

from ._compat import string_types, text_type, iteritems
from .exceptions import (
    PluginError,
    NotCallableError,
    ParamError,
    RunError,
    NotImplementedError,
)
from .version import __version__

comma_pat = compile(r"\s*,\s*")
egg_pat = compile(r"egg=([\w-]+)")


def isValidPrefix(prefix: str, allow_none: bool = False) -> bool:
    """Check if it can be used for blueprint prefix"""
    if prefix is None and allow_none is True:
        return True
    if isinstance(prefix, string_types):
        return (
            prefix.startswith("/")
            and not prefix.endswith("/")
            and "//" not in prefix
            and " " not in prefix
        )
    return False


[docs] def isValidSemver(version: str) -> bool: """Semantic version number - determines whether the version is qualified. The format is MAJOR.Minor.PATCH, more with https://semver.org """ if version and isinstance(version, string_types): return Version.is_valid(version) return False
[docs] def sortedSemver(versions: List[str], sort: str = "ASC") -> List[str]: """Semantically sort the list of version Numbers""" reverse = True if sort.upper() == "DESC" else False if versions and isinstance(versions, (list, tuple)): return sorted( versions, key=cmp_to_key(lambda v1, v2: Version.parse(v1).compare(v2)), reverse=reverse, ) else: raise TypeError("Invaild versions, a list or tuple is right.")
def is_match_version_req(appversion: Optional[str] = None): """Check if the version of flask-pluginkit meets the plugin's version requirement. :param str appversion: Match program versions using operators and grouping symbols. """ #: If None, it is assumed to be compatible with all versions by default. if not appversion: return True if not isinstance(appversion, string_types): appversion = appversion.decode("utf-8") def vermatch(check_ver): if isValidSemver(check_ver): check_ver = ">={}".format(check_ver) try: return Version.parse(__version__).match(check_ver) except ValueError: return False avs = comma_pat.split(appversion) for v in avs: if not vermatch(v): return False else: return True
[docs] class BaseStorage(object): """This is the base class for storage. The available storage classes need to inherit from :class:`~BaseStorage` and override the `get` and `set` methods, it's best to implement the remote method as well. This base class customizes the `__getitem__`, `__setitem__` and `__delitem__` methods so that the user can call it like a dict. .. versionchanged:: 3.4.1 Change :attr:`index` to :attr:`DEFAULT_INDEX` .. versionchanged:: 3.4.1 Add empty :attr:`list`, `get`, `set` method, which must be overridden. """ #: The default index, as the only key, you can override it. DEFAULT_INDEX: str = "flask_pluginkit_dat" @property def index(self): """Get the final index .. versionadded:: 3.4.1 """ return getattr(self, "COVERED_INDEX", None) or self.DEFAULT_INDEX @index.setter def index(self, _covered_index: str): """Set the covered index .. versionadded:: 3.6.0 """ self.COVERED_INDEX = _covered_index @property def list(self) -> Dict[str, Any]: raise NotImplementedError("Please override the list method") def set(self, key: str, value: Any): raise NotImplementedError("Please override the list method") def get(self, key: str) -> Any: raise NotImplementedError("Please override the list method") def remove(self, key: str) -> Any: raise NotImplementedError("Please override the list method") def __getitem__(self, key: str): if hasattr(self, "get"): return self.get(key) else: raise AttributeError("Please override the get method") def __setitem__(self, key: str, value: Any): if hasattr(self, "set"): return self.set(key, value) else: raise AttributeError("Please override the set method") def __delitem__(self, key: str): if hasattr(self, "remove"): return self.remove(key) else: return False def __str__(self): return "<%s object at %s, index is %s>" % ( self.__class__.__name__, hex(id(self)), self.index, ) __repr__ = __str__
[docs] class LocalStorage(BaseStorage): """Local file system storage based on the shelve module.""" def __init__(self, path: Optional[str] = None): self.COVERED_INDEX = path or join(gettempdir(), self.DEFAULT_INDEX) def _open(self, flag: str = "c") -> shelve.Shelf: return shelve.open( filename=abspath(self.index), flag=flag, protocol=2, writeback=False, ) @property def list(self) -> Dict[str, Any]: """list all data :returns: dict """ db = None try: db = self._open(flag="r") except Exception: return dict() else: return dict(db) finally: if db: db.close() def __ck(self, key: str) -> str: if not isinstance(key, text_type): key = key.decode("utf-8") return key
[docs] def set(self, key: str, value: Any): """Set persistent data with shelve. :param key: str: Index key :param value: All supported data types in python :raises: :returns: """ db = None try: db = self._open() db[self.__ck(key)] = value finally: if db: db.close()
[docs] def setmany(self, **mapping: Dict[str, Any]): """Set more data :param mapping: the more k=v .. versionadded:: 3.4.1 """ if mapping and isinstance(mapping, dict): db = self._open() for k, v in iteritems(mapping): db[self.__ck(k)] = v db.close()
[docs] def get(self, key: str, default: Any = None): """Get persistent data from shelve. :returns: data """ try: value = self.list[key] except KeyError: return default else: return value
def remove(self, key: str): db = self._open() del db[key] def __len__(self): return len(self.list)
[docs] class ExpiredLocalStorage(BaseStorage): """Local file system storage based on the shelve module, support exire time.""" def __init__(self, path: Optional[str] = None): self.COVERED_INDEX = path or join(gettempdir(), self.DEFAULT_INDEX)
[docs] def set(self, key: str, value: Any, ttl: int = 0): """Set persistent data with expired time. :param key: str: Index key :param value: All supported data types in python :param ttl: int: expired time in seconds, default is 0(no expired) :raises: """ if not key or not value or ttl < 0: raise ParamError("Invalid key or value or ttl") with shelve.open(self.COVERED_INDEX) as db: etime = int(time()) + ttl if ttl > 0 else 0 db[key] = {"value": value, "etime": etime}
[docs] def get(self, key: str) -> Any: """Gets the key value and automatically deletes and returns None if it has expired""" with shelve.open(self.COVERED_INDEX) as db: entry = db.get(key) if not entry: return None etime = entry["etime"] if etime == 0: return entry["value"] elif time() > etime: del db[key] # 删除过期条目 return None return entry["value"]
[docs] def remove(self, key): """Remove the key from the storage""" with shelve.open(self.COVERED_INDEX) as db: if key in db: del db[key]
@property def list(self) -> Dict[str, Any]: """list all data""" with shelve.open(self.COVERED_INDEX) as db: now = int(time()) # 过滤掉过期的键值对 valid_data = { k: v for k, v in db.items() if v["etime"] > now or v["etime"] == 0 } return {k: v["value"] for k, v in valid_data.items()}
[docs] class RedisStorage(BaseStorage): """Use redis stand-alone storage""" def __init__(self, redis_url=None, redis_connection=None): self._db = self._open(redis_url) if redis_url else redis_connection def _open(self, redis_url): try: from redis import from_url except ImportError: raise ImportError("Please install the redis module, eg: pip install redis") else: return from_url(redis_url) @property def list(self) -> Dict[str, Any]: """list redis hash data""" return {k: json.loads(v) for k, v in iteritems(self._db.hgetall(self.index))}
[docs] def set(self, key: str, value: Any): """set key data""" return self._db.hset(self.index, key, json.dumps(value))
[docs] def setmany(self, **mapping: Dict[str, Any]): """Set more data :param mapping: the more k=v .. versionadded:: 3.4.1 """ if mapping and isinstance(mapping, dict): mapping = {k: json.dumps(v) for k, v in iteritems(mapping)} return self._db.hmset(self.index, mapping)
[docs] def get(self, key: str, default: Any = None) -> Any: """get key original data from redis""" v = self._db.hget(self.index, key) if v: if not isinstance(v, text_type): v = v.decode("utf-8") return json.loads(v) return default
[docs] def remove(self, key: str): """delete key from redis""" return self._db.hdel(self.index, key)
def __len__(self): return self._db.hlen(self.index)
[docs] class JsonResponse(Response): """In response to a return type that cannot be processed. If it is a dict, return json. .. versionadded:: 3.4.0 """
[docs] @classmethod def force_type(cls, rv, environ=None): if isinstance(rv, dict): rv = jsonify(rv) return super(JsonResponse, cls).force_type(rv, environ)
class Attribution(dict): """A dict that allows for object-like property access syntax.""" def __getattr__(self, name): try: return self[name] except KeyError: raise AttributeError(name)
[docs] class DcpManager(object): def __init__(self): self._listeners = {} @property def list(self): return self._listeners
[docs] def push(self, event, callback, position="right"): """Connect a dcp, push a function. :param event: a unique identifier name for dcp. :param callback: corresponding to the event to perform a function. :param position: the position of the insertion function, right(default) or left. The default right is inserted at the end of the event, and left is inserted into the event first. :raises PluginError: the param event or position error :raises NotCallableError: the param callback is not callable .. versionadded:: 3.2.0 """ if event and isinstance(event, string_types): if not callable(callback): raise NotCallableError("The event %s cannot be called" % event) if position not in ("left", "right", "after", "before"): raise PluginError("Invalid position") if event not in self._listeners: self._listeners[event] = deque([callback]) elif position in ("left", "before"): self._listeners[event].appendleft(callback) else: self._listeners[event].append(callback) else: raise PluginError("Invalid event")
[docs] def remove(self, event, callback): """Remove a callback again.""" try: self._listeners[event].remove(callback) except (KeyError, ValueError): return False else: return True
[docs] def emit(self, event, *args, **kwargs): """Emits events for the template context. :returns: strings with :class:`~flask.Markup` """ results = [] funcs = self._listeners.get(event) or [] for f in funcs: rv = f(*args, **kwargs) if isinstance(rv, (list, tuple)): rv = "".join(rv) if rv: if not isinstance(rv, text_type): rv = rv.decode("utf-8") results.append(rv) return Markup("".join(results))
def allowed_uploaded_plugin_suffix(filename: str) -> bool: """Check suffix for uploaded filename .. versionadded:: 3.3.0 """ allow_suffix = [".tar.gz", ".tgz", ".zip"] if isinstance(filename, string_types): for suffix in allow_suffix: if filename.endswith(suffix): return True return False def check_url(addr: str) -> bool: """Check whether UrlAddr is in a valid format, for example:: http://ip:port https://abc.com .. versionadded:: 3.3.0 """ from re import compile, IGNORECASE regex = compile( r"^(?:http)s?://" r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+" r"(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|" r"localhost|" r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" r"(?::\d+)?" r"(?:/?|[/?]\S+)$", IGNORECASE, ) if addr and isinstance(addr, string_types): if regex.match(addr): return True return False
[docs] def is_venv() -> bool: """Determine whether the current environment is under Virtualenv or Venv""" return hasattr(sys, "real_prefix") or ( hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix )
[docs] def pip_install( pkg: str, target_dir: str = "", index: str = "", upgrade: bool = False, quiet: bool = False, ) -> bool: """Use pip to install modules to the specified directory or default user home. :param str pkg: Package name, such as `flask`. :param str target_dir: Install to specified directory. :param str index: The index URL of the package repository, such as `https://pypi.org/simple`. :param bool upgrade: Whether to upgrade the package if it is already installed. :param bool quiet: Whether to suppress output. :raises ParamError: If the package name is invalid. :returns: True if the installation was successful, False otherwise. :rtype: bool """ if not pkg or not isinstance(pkg, string_types): raise ParamError("Invalid package name") cmd = [ sys.executable, "-m", "pip", "install", "--progress-bar", "off", "--timeout", "10", "--retries", "2", "--no-input", "--no-color", "--no-cache-dir", "--disable-pip-version-check", "--no-python-version-warning", "--no-warn-conflicts", "--no-warn-script-location", ] if target_dir: cmd.extend(("--target", target_dir)) else: if not is_venv(): cmd.append("--user") if check_url(index): cmd.extend(("-i", index)) if upgrade is True: cmd.append("--upgrade") if quiet is True: cmd.append("--quiet") cmd.append(pkg) retcode = call(cmd) return retcode == 0
[docs] def pip_list(target_dir: str = "") -> Dict[str, str]: """Get the result of pip list. :param str target_dir: Install to specified directory. :raises RunError: If the pip command fails to execute. :returns: {package_name:version} :rtype: Dict[str, str] """ cmd = [ sys.executable, "-m", "pip", "list", "--format", "json", "--disable-pip-version-check", "--no-python-version-warning", "--no-color", ] try: data: List[Dict[str, str]] = json.loads(check_output(cmd)) except Exception as e: raise RunError("Failed to get pip list: %s" % str(e)) if target_dir and isdir(target_dir): cmd.extend(("--path", target_dir)) try: data.extend(json.loads(check_output(cmd))) except Exception as e: raise RunError("Failed to get pip list with target dir: %s" % str(e)) return {n["name"]: n["version"] for n in data}
[docs] def pip_show(pkg: str, target_dir: str = "") -> Optional[str]: """Query package version. :param str pkg: Package name, such as `flask`. :param str target_dir: Install to specified directory. :raises ParamError: If the package name is invalid. :raises RunError: If the pip command fails to execute. :returns: The version of the package if found, otherwise None. :rtype: Optional[str] """ if not pkg or not isinstance(pkg, string_types): raise ParamError("Invalid package name") ret = pip_list(target_dir=target_dir) if pkg in ret: return ret[pkg]
def get_module_path(module_name:str)-> Optional[str]: """Get the path of the specified module""" try: module = importlib.import_module(module_name) # check __file__ attribute (regular module) if hasattr(module, "__file__") and module.__file__: return abspath(module.__file__) # check if it has __path__ attribute (package) elif hasattr(module, "__path__") and module.__path__: return abspath(module.__path__[0]) else: return None except ImportError: return None