Browse Source

Fix ingest workflow

Matthias Vogelgesang 9 năm trước cách đây
mục cha
commit
e73e29f817
3 tập tin đã thay đổi với 66 bổ sung21 xóa
  1. 48 19
      dm/client.py
  2. 4 0
      dm/exceptions.py
  3. 14 2
      dm/models.py

+ 48 - 19
dm/client.py

@@ -3,15 +3,17 @@ import requests
 import logging
 from xml.etree import ElementTree
 from requests_oauthlib import OAuth1
-from .models import Object, Download, Credentials, Ingest, Investigation
-from .exceptions import Timeout, PermissionError, NotFoundError, ArbitraryError
+from .models import (Object, Download, Credentials, Ingest, IngestStatus,
+        Investigation, User, AccessPoint)
+from .exceptions import (Timeout, PermissionError, NotFoundError,
+        ArbitraryError, InvalidInput)
 
 
 BASE_URL = 'http://kitdm.anka.kit.edu:8080/KITDM/rest'
 KEY = os.getenv('ASTOR_RECO_KEY') or 'secret'
 SECRET = os.getenv('ASTOR_RECO_SECRET') or 'secret'
-TOKEN = os.getenv('ASTOR_RECO_TOKEN') or 'xQ6dTveTtBpFmmZ7'
-TOKEN_SECRET = os.getenv('ASTOR_RECO_TOKEN_SECRET') or 'T7rCUS2lFXNAXv6e'
+TOKEN = os.getenv('ASTOR_RECO_TOKEN')
+TOKEN_SECRET = os.getenv('ASTOR_RECO_TOKEN_SECRET')
 
 
 def check_response(response):
@@ -19,17 +21,21 @@ def check_response(response):
         raise PermissionError()
     elif response.status_code == 404:
         raise NotFoundError()
+    elif response.status_code == 412:
+        raise InvalidInput()
     elif response.status_code != 200:
+        print response.text
         raise ArbitraryError(response)
 
 
 class Client(object):
     def __init__(self, base_url=BASE_URL, key=KEY, secret=SECRET,
                  token=TOKEN, token_secret=TOKEN_SECRET):
-        self.auth = OAuth1(KEY, client_secret=SECRET,
-                           resource_owner_key=TOKEN,
-                           resource_owner_secret=TOKEN_SECRET,
-                           signature_method='PLAINTEXT')
+        self.session = requests.Session()
+        self.session.auth = OAuth1(key, client_secret=secret,
+                                   resource_owner_key=token,
+                                   resource_owner_secret=token_secret,
+                                   signature_method='PLAINTEXT')
         self.base_url = base_url
         self.log = logging.getLogger(__name__)
 
@@ -72,15 +78,20 @@ class Client(object):
     def get_objects(self, limit=None):
         return [self.get_object(oid) for oid in self.get_object_ids(limit)]
 
-    def create_object(self, investigation_id, label=None, uploader_id=None, note=None):
+    def create_object(self, investigation_id, uploader_id, label=None, note=None):
         url = self.url('basemetadata/investigations', investigation_id, 'digitalObjects')
-        params = dict(label=label, uploader_id=uploader_id, note=note)
+        return Object(self.post(url, label=label, uploaderId=uploader_id, note=note))
 
-        try:
-            response = requests.post(url, data=params, timeout=5)
-            check_response(response)
-        except requests.exceptions.Timeout:
-            raise Timeout()
+    def create_ingest(self, object_id, accesspoint_uuid):
+        url = self.url('staging/ingests')
+        return Ingest(self.post(url, objectId=object_id, accessPoint=accesspoint_uuid))
+
+    def set_ingest_status(self, oid, status, error_message=""):
+        if not isinstance(status, IngestStatus):
+            raise TypeError
+
+        params = dict(status=status, errorMessage=error_message)
+        self.session.put(self.url('staging/ingests', oid), data=params)
 
     def get_organization(self, oid):
         return Organization(self.get_response(self.url('dataorganization/organization', oid)))
@@ -91,19 +102,37 @@ class Client(object):
     def get_accesspoint(self, oid):
         return AccessPoint(self.get_response(self.url('staging/accesspoints', oid)))
 
-    def get_accesspoints(self):
-        return [self.get_accesspoint(oid) for oid in self.get_accesspoint_ids()]
+    def get_accesspoints(self, scheme=None):
+        aps = (self.get_accesspoint(oid) for oid in self.get_accesspoint_ids())
+
+        if scheme:
+            return [ap for ap in aps if ap.url.startswith(scheme) ]
+
+        return list(aps)
 
     def get_collection(self, url, id_name, limit=None):
         params = {'results': limit} if limit else {}
         root = self.get_response(self.url(url), **params)
         return [int(e.text) for e in root.findall('./entities/entity/{}'.format(id_name))]
 
+    def get_user(self, oid):
+        return User(self.get_response(self.url('usergroup/users', oid)), self)
+
     def get_response(self, url, **params):
-        self.log.debug("Accessing {}".format(url))
+        self.log.debug("GET {}".format(url))
+
+        try:
+            response = self.session.get(url, params=params, timeout=5)
+            check_response(response)
+            return ElementTree.fromstring(response.text)
+        except requests.exceptions.Timeout:
+            raise Timeout()
+
+    def post(self, url, **params):
+        self.log.debug("POST {}".format(url))
 
         try:
-            response = requests.get(url, auth=self.auth, params=params, timeout=5)
+            response = self.session.post(url, data=params, timeout=5)
             check_response(response)
             return ElementTree.fromstring(response.text)
         except requests.exceptions.Timeout:

+ 4 - 0
dm/exceptions.py

@@ -11,6 +11,10 @@ class NotFoundError(Error):
     pass
 
 
+class InvalidInput(Error):
+    pass
+
+
 class ArbitraryError(Error):
     def __init__(self, response):
         text = response.text[:46] + ' ...' if len(response.text) > 50 else response.text

+ 14 - 2
dm/models.py

@@ -52,7 +52,7 @@ class IngestStatus(Enum):
 class Ingest(EntityContainer):
     def __init__(self, root):
         super(Ingest, self).__init__(root)
-        self.id = self.entities.get('id')
+        self.id = int(self.entities.get('id'))
         self.staging_url = self.entities.get('stagingUrl')
         self.object_uuid = self.entities.get('digitalObjectUuid')
         self.owner_uuid = self.entities.get('owernUuid')
@@ -70,4 +70,16 @@ class AccessPoint(EntityContainer):
     def __init__(self, root):
         super(AccessPoint, self).__init__(root)
         self.url = self.entities['remoteBaseUrl']
-        self.identifier = self.entities['uniqueIdentifier']
+        self.uuid = self.entities['uniqueIdentifier']
+
+
+class User(EntityContainer):
+    def __init__(self, root, client):
+        super(User, self).__init__(root)
+        self.client = client
+        self.id = int(self.entities['userId'])
+        self.ldap_id = int(self.entities['distinguishedName'])
+
+    def get_credentials(self):
+        response = self.client.get_response(self.client.url('information/properties'), userId=self.ldap_id)
+        return Credentials(response)