# Copyright 2020 Karlsruhe Institute of Technology
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from flask import current_app
from flask import send_file
from flask_babel import lazy_gettext as _l
import kadi.lib.constants as const
from .core import BaseStorage
[docs]
class LocalStorage(BaseStorage):
    """Storage provider that uses the local file system.

    Files are stored below a root directory. Each file identifier is mapped to a
    relative path by splitting its first ``num_dirs * dir_len`` characters into
    ``num_dirs`` nested directories of ``dir_len`` characters each, while the
    remaining characters are used as the file name. With the default values, the
    identifier ``"abcdef12"`` is stored as ``ab/cd/ef/12`` below the root directory.

    :param root_directory: The directory the storage provider operates in. Must be an
        absolute path.
    :param num_dirs: (optional) Number of directories for local file paths generated by
        this storage provider. Values below ``0`` are treated as ``0``, in which case
        the identifier itself is used as the file name.
    :param dir_len: (optional) Length of each directory for local file paths generated
        by this storage provider. Values below ``1`` are treated as ``1``.
    :raises ValueError: If the given root directory is not suitable.
    """

    def __init__(self, root_directory, num_dirs=3, dir_len=2):
        super().__init__(const.STORAGE_TYPE_LOCAL, storage_name=_l("Local"))

        if not os.path.isabs(root_directory):
            raise ValueError(
                f"Given root directory '{root_directory}' is not an absolute path."
            )

        self._root_directory = root_directory
        # Out of range values are clamped instead of rejected.
        self._num_dirs = max(num_dirs, 0)
        self._dir_len = max(dir_len, 1)

    @property
    def root_directory(self):
        """Get the root directory of this storage."""
        return self._root_directory

    def _path_from_identifier(self, identifier):
        """Map a file identifier to a path relative to the root directory.

        :param identifier: The file identifier.
        :return: The relative path, or ``None`` if the identifier is too short to
            leave a non-empty file name after splitting off all directories.
        """
        if self._num_dirs == 0:
            return identifier

        prefix_len = self._dir_len * self._num_dirs

        if len(identifier) <= prefix_len:
            return None

        # Only the prefix is split into directories, the rest is the file name.
        dirs = [
            identifier[i : i + self._dir_len]
            for i in range(0, prefix_len, self._dir_len)
        ]
        return os.path.join(*dirs, identifier[prefix_len:])

    def _create_filepath(self, identifier):
        """Build the absolute local file path of a file identifier.

        :param identifier: The file identifier.
        :return: The absolute file path below the root directory.
        :raises ValueError: If the identifier is empty, contains a path separator, is
            too short for the configured directory layout or would produce a path
            component referring to the current or parent directory.
        """
        error_msg = (
            f"Given file identifier '{identifier}' is not suitable for creating a local"
            " file path."
        )

        # "os.altsep" is only set on some platforms, e.g. "/" on Windows.
        separators = [sep for sep in (os.sep, os.altsep) if sep]

        if not identifier or any(sep in identifier for sep in separators):
            raise ValueError(error_msg)

        filepath = self._path_from_identifier(identifier)

        if filepath is None:
            raise ValueError(error_msg)

        # Since identifiers are split into fixed-width chunks, components such as ".."
        # could otherwise escape the root directory, e.g. "......." would be mapped to
        # "../../../." with the default layout.
        if any(part in (os.curdir, os.pardir) for part in filepath.split(os.sep)):
            raise ValueError(error_msg)

        return os.path.join(self.root_directory, filepath)

    def _remove_empty_parent_dirs(self, filepath):
        """Remove empty parent directories of a file path, up to the root directory.

        The root directory itself is never removed. This is a best-effort cleanup:
        it stops at the first directory that is missing, not empty or cannot be
        removed for any other reason, without raising.

        :param filepath: The file path whose parent directories should be cleaned up.
        """
        current_dir = os.path.dirname(filepath)

        while True:
            try:
                if os.path.samefile(self.root_directory, current_dir):
                    break

                # Only succeeds for empty directories, which also avoids races with
                # concurrent file operations in the same directory.
                os.rmdir(current_dir)
            except OSError:
                break

            current_dir = os.path.dirname(current_dir)

    # Returns whether a regular file exists for the given identifier.
    def exists(self, identifier):
        filepath = self._create_filepath(identifier)
        return os.path.isfile(filepath)

    # Returns the size of the file in bytes, as reported by the file system.
    def get_size(self, identifier):
        filepath = self._create_filepath(identifier)
        return os.path.getsize(filepath)

    # Determines the MIME type from the file's content and size.
    def get_mimetype(self, identifier):
        file_size = self.get_size(identifier)

        with self.open(identifier) as f:
            return self._get_mimetype(f, file_size)

    # Opens the file of the given identifier, returning a regular file object.
    def open(self, identifier, mode="rb", encoding=None):
        filepath = self._create_filepath(identifier)

        # Parent directories are only created for modes that may create the file, so
        # that e.g. trying to read a missing file does not leave empty directories.
        if any(flag in mode for flag in "wax"):
            os.makedirs(os.path.dirname(filepath), exist_ok=True)

        # pylint: disable=consider-using-with
        return open(filepath, mode=mode, encoding=encoding)

    # Writes the content of the given stream to the file of the given identifier,
    # overwriting any existing content.
    def save(self, identifier, stream, max_size=None):
        with self.open(identifier, mode="wb") as f:
            return self._save(f, stream, max_size=max_size, calculate_checksum=True)

    # Moves a file to another identifier, replacing any existing destination file and
    # cleaning up source directories that became empty.
    def move(self, src_identifier, dst_identifier):
        src_filepath = self._create_filepath(src_identifier)
        dst_filepath = self._create_filepath(dst_identifier)

        os.makedirs(os.path.dirname(dst_filepath), exist_ok=True)
        os.replace(src_filepath, dst_filepath)

        self._remove_empty_parent_dirs(src_filepath)

    # Deletes the file of the given identifier, if it exists, and cleans up parent
    # directories that became empty.
    def delete(self, identifier):
        filepath = self._create_filepath(identifier)

        try:
            os.remove(filepath)
        except FileNotFoundError:
            pass

        self._remove_empty_parent_dirs(filepath)

    # Writes the files of the given identifiers, in order, into a single file.
    # NOTE(review): presumably "_save" copies the chunk's content into "f"; confirm
    # against BaseStorage.
    def merge(self, identifier, identifier_list):
        with self.open(identifier, mode="wb") as f:
            for chunk_identifier in identifier_list:
                with self.open(chunk_identifier) as f_chunk:
                    self._save(f, f_chunk)

    # Builds a Flask response sending the file of the given identifier.
    def download(self, identifier, *, filename, mimetype, as_attachment=True):
        filepath = self._create_filepath(identifier)

        response = send_file(
            filepath,
            download_name=filename,
            mimetype=mimetype,
            as_attachment=as_attachment,
            # In production environments, the web server handles conditional/range
            # responses via "X-Sendfile".
            conditional=current_app.environment != const.ENV_PRODUCTION,
        )

        # Always return the "Accept-Ranges" header, even for regular requests.
        response.headers["Accept-Ranges"] = "bytes"

        return response