client.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. import os
  2. import requests
  3. import logging
  4. from xml.etree import ElementTree
  5. from requests_oauthlib import OAuth1
  6. from .models import (Object, Download, Ingest, Investigation, User,
  7. AccessPoint)
  8. from .exceptions import (Timeout, PermissionError, NotFoundError,
  9. ArbitraryError, InvalidInput)
  10. BASE_URL = 'http://kitdm.anka.kit.edu:8080/KITDM/rest'
  11. KEY = os.getenv('ASTOR_RECO_KEY') or 'secret'
  12. SECRET = os.getenv('ASTOR_RECO_SECRET') or 'secret'
  13. TOKEN = os.getenv('ASTOR_RECO_TOKEN')
  14. TOKEN_SECRET = os.getenv('ASTOR_RECO_TOKEN_SECRET')
  15. def check_response(response):
  16. if response.status_code == 403:
  17. raise PermissionError()
  18. elif response.status_code == 404:
  19. raise NotFoundError()
  20. elif response.status_code == 412:
  21. raise InvalidInput()
  22. elif response.status_code != 200:
  23. raise ArbitraryError(response)
  24. class Client(object):
  25. def __init__(self, base_url=BASE_URL, key=KEY, secret=SECRET,
  26. token=TOKEN, token_secret=TOKEN_SECRET):
  27. self.session = requests.Session()
  28. self.session.auth = OAuth1(key, client_secret=secret,
  29. resource_owner_key=token,
  30. resource_owner_secret=token_secret,
  31. signature_method='PLAINTEXT')
  32. self.base_url = base_url
  33. self.log = logging.getLogger(__name__)
  34. def url(self, *args):
  35. return '/'.join([self.base_url] + [str(arg) for arg in args])
  36. def get_ingest_ids(self, limit=None):
  37. return self.get_collection('staging/ingests', 'id', limit)
  38. def get_ingests(self, limit=-1):
  39. return [Ingest(self, oid) for oid in self.get_ingest_ids(limit)]
  40. def get_investigation_ids(self, limit=None):
  41. return self.get_collection('basemetadata/investigations', 'investigationId', limit)
  42. def get_investigations(self, limit=-1):
  43. return [Investigation(self, oid) for oid in self.get_investigation_ids(limit)]
  44. def get_object_ids(self, limit=None, since=None):
  45. return self.get_collection('basemetadata/digitalObjects', 'baseId', limit)
  46. def get_objects(self, limit=-1, predicate=lambda o: True):
  47. objects = (Object(self, oid) for oid in self.get_object_ids(limit))
  48. return [o for o in objects if predicate(o)]
  49. def create_object(self, investigation_id, uploader, label=None, note=None):
  50. url = self.url('basemetadata/investigations', investigation_id, 'digitalObjects')
  51. return Object(self, -1, root=self.post(url, label=label, uploaderId=uploader.id, note=note))
  52. def delete_object(self, obj):
  53. self.delete(self.url('basemetadata/digitalObjects', obj.id))
  54. def get_download_ids(self, limit=None):
  55. return self.get_collection('staging/downloads', 'id', limit)
  56. def get_downloads(self, limit=-1):
  57. return [Download(self, oid) for oid in self.get_download_ids(limit)]
  58. def delete_ingest(self, ingest):
  59. self.delete(self.url('staging/ingests', ingest.id))
  60. def get_accesspoint_ids(self):
  61. return self.get_collection('staging/accesspoints', 'id')
  62. def get_accesspoints(self, scheme=None):
  63. aps = (AccessPoint(self, oid) for oid in self.get_accesspoint_ids())
  64. if scheme:
  65. return [ap for ap in aps if ap.url.startswith(scheme)]
  66. return list(aps)
  67. def get_collection(self, url, id_name, limit=None):
  68. params = {'results': limit} if limit else {}
  69. root = self.get(self.url(url), **params)
  70. return [int(e.text) for e in root.findall('./entities/entity/{}'.format(id_name))]
  71. def get_user(self, oid):
  72. return User(self, oid)
  73. def delete(self, url):
  74. self.log.debug("DELETE {}".format(url))
  75. try:
  76. response = self.session.delete(url, timeout=5)
  77. check_response(response)
  78. except requests.exceptions.Timeout:
  79. raise Timeout
  80. def get(self, url, **params):
  81. self.log.debug("GET {}".format(url))
  82. try:
  83. response = self.session.get(url, params=params, timeout=5)
  84. check_response(response)
  85. return ElementTree.fromstring(response.text)
  86. except requests.exceptions.Timeout:
  87. raise Timeout()
  88. def post(self, url, **params):
  89. self.log.debug("POST {}".format(url))
  90. try:
  91. response = self.session.post(url, data=params, timeout=5)
  92. check_response(response)
  93. return ElementTree.fromstring(response.text)
  94. except requests.exceptions.Timeout:
  95. raise Timeout()