|
@@ -3,15 +3,17 @@ import requests
|
|
|
import logging
|
|
|
from xml.etree import ElementTree
|
|
|
from requests_oauthlib import OAuth1
|
|
|
-from .models import Object, Download, Credentials, Ingest, Investigation
|
|
|
-from .exceptions import Timeout, PermissionError, NotFoundError, ArbitraryError
|
|
|
+from .models import (Object, Download, Credentials, Ingest, IngestStatus,
|
|
|
+ Investigation, User, AccessPoint)
|
|
|
+from .exceptions import (Timeout, PermissionError, NotFoundError,
|
|
|
+ ArbitraryError, InvalidInput)
|
|
|
|
|
|
|
|
|
BASE_URL = 'http://kitdm.anka.kit.edu:8080/KITDM/rest'
|
|
|
KEY = os.getenv('ASTOR_RECO_KEY') or 'secret'
|
|
|
SECRET = os.getenv('ASTOR_RECO_SECRET') or 'secret'
|
|
|
-TOKEN = os.getenv('ASTOR_RECO_TOKEN') or 'xQ6dTveTtBpFmmZ7'
|
|
|
-TOKEN_SECRET = os.getenv('ASTOR_RECO_TOKEN_SECRET') or 'T7rCUS2lFXNAXv6e'
|
|
|
+TOKEN = os.getenv('ASTOR_RECO_TOKEN')
|
|
|
+TOKEN_SECRET = os.getenv('ASTOR_RECO_TOKEN_SECRET')
|
|
|
|
|
|
|
|
|
def check_response(response):
|
|
@@ -19,17 +21,21 @@ def check_response(response):
|
|
|
raise PermissionError()
|
|
|
elif response.status_code == 404:
|
|
|
raise NotFoundError()
|
|
|
+ elif response.status_code == 412:
|
|
|
+ raise InvalidInput()
|
|
|
elif response.status_code != 200:
|
|
|
+ print response.text
|
|
|
raise ArbitraryError(response)
|
|
|
|
|
|
|
|
|
class Client(object):
|
|
|
def __init__(self, base_url=BASE_URL, key=KEY, secret=SECRET,
|
|
|
token=TOKEN, token_secret=TOKEN_SECRET):
|
|
|
- self.auth = OAuth1(KEY, client_secret=SECRET,
|
|
|
- resource_owner_key=TOKEN,
|
|
|
- resource_owner_secret=TOKEN_SECRET,
|
|
|
- signature_method='PLAINTEXT')
|
|
|
+ self.session = requests.Session()
|
|
|
+ self.session.auth = OAuth1(key, client_secret=secret,
|
|
|
+ resource_owner_key=token,
|
|
|
+ resource_owner_secret=token_secret,
|
|
|
+ signature_method='PLAINTEXT')
|
|
|
self.base_url = base_url
|
|
|
self.log = logging.getLogger(__name__)
|
|
|
|
|
@@ -72,15 +78,20 @@ class Client(object):
|
|
|
def get_objects(self, limit=None):
|
|
|
return [self.get_object(oid) for oid in self.get_object_ids(limit)]
|
|
|
|
|
|
- def create_object(self, investigation_id, label=None, uploader_id=None, note=None):
|
|
|
+ def create_object(self, investigation_id, uploader_id, label=None, note=None):
|
|
|
url = self.url('basemetadata/investigations', investigation_id, 'digitalObjects')
|
|
|
- params = dict(label=label, uploader_id=uploader_id, note=note)
|
|
|
+ return Object(self.post(url, label=label, uploaderId=uploader_id, note=note))
|
|
|
|
|
|
- try:
|
|
|
- response = requests.post(url, data=params, timeout=5)
|
|
|
- check_response(response)
|
|
|
- except requests.exceptions.Timeout:
|
|
|
- raise Timeout()
|
|
|
+ def create_ingest(self, object_id, accesspoint_uuid):
|
|
|
+ url = self.url('staging/ingests')
|
|
|
+ return Ingest(self.post(url, objectId=object_id, accessPoint=accesspoint_uuid))
|
|
|
+
|
|
|
+ def set_ingest_status(self, oid, status, error_message=""):
|
|
|
+ if not isinstance(status, IngestStatus):
|
|
|
+ raise TypeError
|
|
|
+
|
|
|
+ params = dict(status=status, errorMessage=error_message)
|
|
|
+ self.session.put(self.url('staging/ingests', oid), data=params)
|
|
|
|
|
|
def get_organization(self, oid):
|
|
|
return Organization(self.get_response(self.url('dataorganization/organization', oid)))
|
|
@@ -91,19 +102,37 @@ class Client(object):
|
|
|
def get_accesspoint(self, oid):
|
|
|
return AccessPoint(self.get_response(self.url('staging/accesspoints', oid)))
|
|
|
|
|
|
- def get_accesspoints(self):
|
|
|
- return [self.get_accesspoint(oid) for oid in self.get_accesspoint_ids()]
|
|
|
+ def get_accesspoints(self, scheme=None):
|
|
|
+ aps = (self.get_accesspoint(oid) for oid in self.get_accesspoint_ids())
|
|
|
+
|
|
|
+ if scheme:
|
|
|
+ return [ap for ap in aps if ap.url.startswith(scheme) ]
|
|
|
+
|
|
|
+ return list(aps)
|
|
|
|
|
|
def get_collection(self, url, id_name, limit=None):
|
|
|
params = {'results': limit} if limit else {}
|
|
|
root = self.get_response(self.url(url), **params)
|
|
|
return [int(e.text) for e in root.findall('./entities/entity/{}'.format(id_name))]
|
|
|
|
|
|
+ def get_user(self, oid):
|
|
|
+ return User(self.get_response(self.url('usergroup/users', oid)), self)
|
|
|
+
|
|
|
def get_response(self, url, **params):
|
|
|
- self.log.debug("Accessing {}".format(url))
|
|
|
+ self.log.debug("GET {}".format(url))
|
|
|
+
|
|
|
+ try:
|
|
|
+ response = self.session.get(url, params=params, timeout=5)
|
|
|
+ check_response(response)
|
|
|
+ return ElementTree.fromstring(response.text)
|
|
|
+ except requests.exceptions.Timeout:
|
|
|
+ raise Timeout()
|
|
|
+
|
|
|
+ def post(self, url, **params):
|
|
|
+ self.log.debug("POST {}".format(url))
|
|
|
|
|
|
try:
|
|
|
- response = requests.get(url, auth=self.auth, params=params, timeout=5)
|
|
|
+ response = self.session.post(url, data=params, timeout=5)
|
|
|
check_response(response)
|
|
|
return ElementTree.fromstring(response.text)
|
|
|
except requests.exceptions.Timeout:
|