This is an automated email from the ASF dual-hosted git repository.
ztang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git
The following commit(s) were added to refs/heads/master by this push:
new 04f48c5 SUBMARINE-526. Use yapf to format Python code
04f48c5 is described below
commit 04f48c528004fb9b4b661758c5c410f12e68efec
Author: pingsutw <[email protected]>
AuthorDate: Mon Jun 15 16:32:57 2020 +0800
SUBMARINE-526. Use yapf to format Python code
### What is this PR for?
Use yapf to auto-format Python code
It's useful for formatting code generated by Swagger, and it helps developers
fix code style issues.
### What type of PR is it?
[Improvement]
### Todos
* [ ] - Task
### What is the Jira issue?
https://issues.apache.org/jira/browse/SUBMARINE-526
### How should this be tested?
https://travis-ci.org/github/pingsutw/hadoop-submarine/builds/698787450
### Screenshots (if appropriate)
### Questions:
* Do the license files need updating? No
* Are there breaking changes for older versions? No
* Does this need documentation? No
Author: pingsutw <[email protected]>
Closes #312 from pingsutw/SUBMARINE-526 and squashes the following commits:
ea83a0a [pingsutw] SUBMARINE-526. Use yapf to format Python code
---
docs/submarine-sdk/pysubmarine/README.md | 31 ++--
pom.xml | 1 +
submarine-sdk/pysubmarine/.style.yapf | 4 +
.../github-actions/{lint.sh => auto-format.sh} | 8 +-
.../github-actions/lint-requirements.txt | 4 +-
submarine-sdk/pysubmarine/github-actions/lint.sh | 18 ++-
submarine-sdk/pysubmarine/submarine/__init__.py | 25 +--
.../submarine/entities/_submarine_object.py | 16 +-
submarine-sdk/pysubmarine/submarine/exceptions.py | 6 +-
.../pysubmarine/submarine/job/__init__.py | 2 +-
.../pysubmarine/submarine/job/api/__init__.py | 4 +-
.../pysubmarine/submarine/job/api/jobs_api.py | 111 +++++++-------
.../pysubmarine/submarine/job/api_client.py | 169 ++++++++++++---------
.../pysubmarine/submarine/job/configuration.py | 12 +-
.../submarine/job/models/job_library_spec.py | 25 +--
.../pysubmarine/submarine/job/models/job_spec.py | 25 +--
.../submarine/job/models/job_task_spec.py | 29 ++--
.../submarine/job/models/json_response.py | 24 +--
submarine-sdk/pysubmarine/submarine/job/rest.py | 141 +++++++++++------
.../submarine/ml/pytorch/input/libsvm_dataset.py | 31 ++--
.../submarine/ml/pytorch/layers/core.py | 55 ++++---
.../ml/pytorch/model/base_pytorch_model.py | 60 +++-----
.../submarine/ml/pytorch/model/ctr/deepfm.py | 35 ++---
.../pysubmarine/submarine/ml/pytorch/parameters.py | 1 -
.../pysubmarine/submarine/ml/pytorch/registries.py | 4 +-
.../submarine/ml/tensorflow/input/input.py | 17 ++-
.../submarine/ml/tensorflow/layers/core.py | 82 +++++++---
.../submarine/ml/tensorflow/model/base_tf_model.py | 53 +++----
.../submarine/ml/tensorflow/model/deepfm.py | 10 +-
.../submarine/ml/tensorflow/model/fm.py | 7 +-
.../submarine/ml/tensorflow/model/nfm.py | 10 +-
.../submarine/ml/tensorflow/optimizer.py | 12 +-
.../submarine/ml/tensorflow/parameters.py | 1 -
.../submarine/ml/tensorflow/registries.py | 4 +-
.../submarine/store/database/db_types.py | 8 +-
.../pysubmarine/submarine/store/database/models.py | 53 ++++---
.../submarine/store/sqlalchemy_store.py | 29 ++--
.../pysubmarine/submarine/tracking/__init__.py | 4 +-
.../pysubmarine/submarine/tracking/client.py | 12 +-
.../pysubmarine/submarine/tracking/fluent.py | 11 +-
.../pysubmarine/submarine/tracking/utils.py | 1 +
.../pysubmarine/submarine/utils/__init__.py | 4 +-
submarine-sdk/pysubmarine/submarine/utils/env.py | 15 +-
.../pysubmarine/submarine/utils/fileio.py | 6 +-
.../pysubmarine/submarine/utils/rest_utils.py | 31 ++--
.../pysubmarine/submarine/utils/tf_utils.py | 59 ++++---
.../pysubmarine/submarine/utils/validation.py | 29 ++--
.../pysubmarine/tests/ml/pytorch/model/conftest.py | 11 +-
.../tests/ml/pytorch/test_loss_pytorch.py | 1 +
.../tests/ml/pytorch/test_metric_pytorch.py | 1 +
.../tests/ml/pytorch/test_optimizer_pytorch.py | 1 +
.../tests/ml/tensorflow/model/conftest.py | 3 +-
.../ml/tensorflow/model/test_base_tf_model.py | 6 +-
.../tests/ml/tensorflow/model/test_deepfm.py | 1 -
.../tests/ml/tensorflow/model/test_fm.py | 1 -
.../tests/ml/tensorflow/model/test_nfm.py | 1 -
.../tests/ml/tensorflow/test_optimizer.py | 4 +-
.../tests/store/test_sqlalchemy_store.py | 16 +-
.../pysubmarine/tests/tracking/test_tracking.py | 12 +-
.../pysubmarine/tests/tracking/test_utils.py | 16 +-
submarine-sdk/pysubmarine/tests/utils/test_env.py | 8 +-
.../pysubmarine/tests/utils/test_rest_utils.py | 28 ++--
.../pysubmarine/tests/utils/test_tf_utils.py | 27 +++-
.../pysubmarine/tests/utils/test_validation.py | 29 +++-
64 files changed, 864 insertions(+), 601 deletions(-)
diff --git a/docs/submarine-sdk/pysubmarine/README.md
b/docs/submarine-sdk/pysubmarine/README.md
index ac872a0..a401ef6 100644
--- a/docs/submarine-sdk/pysubmarine/README.md
+++ b/docs/submarine-sdk/pysubmarine/README.md
@@ -1,15 +1,15 @@
-<!---
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. See accompanying LICENSE file.
+<!---
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. See accompanying LICENSE file.
-->
# PySubmarine
@@ -23,7 +23,7 @@ tracking experiment metrics, parameters.
## Package setup
- Clone repository
```bash
-git clone https://github.com/apache/submarine.git
+git clone https://github.com/apache/submarine.git
cd submarine/submarine-sdk/pysubmarine
```
- Install pip package
@@ -34,7 +34,10 @@ pip install .
```bash
pytest --cov=submarine -vs
```
-
+- Auto format code
+```bash
+./submarine-sdk/pysubmarine/github-actions/auto-format.sh
+```
- Run checkstyle
```bash
./submarine-sdk/pysubmarine/github-actions/lint.sh
diff --git a/pom.xml b/pom.xml
index eb4f4ed..2b4b3fb 100644
--- a/pom.xml
+++ b/pom.xml
@@ -553,6 +553,7 @@
<exclude>**/*.conf</exclude>
<exclude>**/*.libsvm</exclude>
<exclude>**/*.yaml</exclude>
+ <exclude>**/*.yapf</exclude>
<exclude>**/*.libsvm</exclude>
<exclude>**/src/main/resources/META-INF/services/org.apache.hadoop.security.SecurityInfo</exclude>
<exclude>**/src/test/resources/typicalHistFolder/job1/application123-1-1-user1-SUCCEEDED.jhist</exclude>
diff --git a/submarine-sdk/pysubmarine/.style.yapf
b/submarine-sdk/pysubmarine/.style.yapf
new file mode 100644
index 0000000..34e7202
--- /dev/null
+++ b/submarine-sdk/pysubmarine/.style.yapf
@@ -0,0 +1,4 @@
+[style]
+based_on_style = google
+indent_width: 4
+continuation_indent_width: 4
diff --git a/submarine-sdk/pysubmarine/github-actions/lint.sh
b/submarine-sdk/pysubmarine/github-actions/auto-format.sh
similarity index 82%
copy from submarine-sdk/pysubmarine/github-actions/lint.sh
copy to submarine-sdk/pysubmarine/github-actions/auto-format.sh
index 3c403e1..1609e92 100755
--- a/submarine-sdk/pysubmarine/github-actions/lint.sh
+++ b/submarine-sdk/pysubmarine/github-actions/auto-format.sh
@@ -1,3 +1,4 @@
+#!/usr/bin/env bash
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
@@ -13,14 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-#!/usr/bin/env bash
set -ex
FWDIR="$(cd "$(dirname "$0")"; pwd)"
cd "$FWDIR"
cd ..
-pycodestyle --max-line-length=100 -- submarine tests
-pylint --ignore job --msg-template="{path} ({line},{column}): [{msg_id}
{symbol}] {msg}" --rcfile=pylintrc -- submarine tests
+# Autoformat code
+yapf -i submarine/**/*.py tests/**/*.py
+# Sort imports
+isort submarine/**/*.py tests/**/*.py
set +ex
diff --git a/submarine-sdk/pysubmarine/github-actions/lint-requirements.txt
b/submarine-sdk/pysubmarine/github-actions/lint-requirements.txt
index ab7efab..38e66c9 100644
--- a/submarine-sdk/pysubmarine/github-actions/lint-requirements.txt
+++ b/submarine-sdk/pysubmarine/github-actions/lint-requirements.txt
@@ -15,4 +15,6 @@
pep8==1.7.1
pylint==2.4.4
-pycodestyle==2.5.0
\ No newline at end of file
+pycodestyle==2.5.0
+yapf==0.30.0
+isort==4.3.21
diff --git a/submarine-sdk/pysubmarine/github-actions/lint.sh
b/submarine-sdk/pysubmarine/github-actions/lint.sh
index 3c403e1..83a7fac 100755
--- a/submarine-sdk/pysubmarine/github-actions/lint.sh
+++ b/submarine-sdk/pysubmarine/github-actions/lint.sh
@@ -1,3 +1,4 @@
+#!/usr/bin/env bash
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
@@ -13,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-#!/usr/bin/env bash
set -ex
FWDIR="$(cd "$(dirname "$0")"; pwd)"
@@ -22,5 +22,21 @@ cd ..
pycodestyle --max-line-length=100 -- submarine tests
pylint --ignore job --msg-template="{path} ({line},{column}): [{msg_id}
{symbol}] {msg}" --rcfile=pylintrc -- submarine tests
+./github-actions/auto-format.sh
+
+GIT_STATUS="$(git status --porcelain)"
+GIT_DIFF="$(git diff)"
+if [ "$GIT_STATUS" ]; then
+ echo "Code is not formatted by yapf and isort. Please run
./github-actions/auto-format.sh"
+ echo "Git status is"
+ echo
"------------------------------------------------------------------"
+ echo "$GIT_STATUS"
+ echo "Git diff is"
+ echo
"------------------------------------------------------------------"
+ echo "$GIT_DIFF"
+ exit 1
+else
+ echo "Test successful"
+fi
set +ex
diff --git a/submarine-sdk/pysubmarine/submarine/__init__.py
b/submarine-sdk/pysubmarine/submarine/__init__.py
index 6985d07..5484d8b 100644
--- a/submarine-sdk/pysubmarine/submarine/__init__.py
+++ b/submarine-sdk/pysubmarine/submarine/__init__.py
@@ -13,27 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from submarine.job import ApiClient, JobLibrarySpec, JobSpec, JobTaskSpec,\
- Configuration, JobsApi
-
-import submarine.tracking.fluent
import submarine.tracking as tracking
+import submarine.tracking.fluent
+from submarine.job import (ApiClient, Configuration, JobLibrarySpec, JobsApi,
+ JobSpec, JobTaskSpec)
log_param = submarine.tracking.fluent.log_param
log_metric = submarine.tracking.fluent.log_metric
set_tracking_uri = tracking.set_tracking_uri
get_tracking_uri = tracking.get_tracking_uri
-
-__all__ = ["log_metric",
- "log_param",
- "set_tracking_uri",
- "get_tracking_uri",
- "ApiClient",
- "JobLibrarySpec",
- "JobSpec",
- "JobTaskSpec",
- "Configuration",
- "JobsApi"
- ]
+__all__ = [
+ "log_metric", "log_param", "set_tracking_uri", "get_tracking_uri",
+ "ApiClient", "JobLibrarySpec", "JobSpec", "JobTaskSpec", "Configuration",
+ "JobsApi"
+]
diff --git a/submarine-sdk/pysubmarine/submarine/entities/_submarine_object.py
b/submarine-sdk/pysubmarine/submarine/entities/_submarine_object.py
index 29137b5..92b0f35 100644
--- a/submarine-sdk/pysubmarine/submarine/entities/_submarine_object.py
+++ b/submarine-sdk/pysubmarine/submarine/entities/_submarine_object.py
@@ -17,6 +17,7 @@ import pprint
class _SubmarineObject:
+
def __iter__(self):
# Iterate through list of properties and yield as key -> value
for prop in self._properties():
@@ -24,11 +25,16 @@ class _SubmarineObject:
@classmethod
def _properties(cls):
- return sorted([p for p in cls.__dict__ if isinstance(getattr(cls, p),
property)])
+ return sorted(
+ [p for p in cls.__dict__ if isinstance(getattr(cls, p), property)])
@classmethod
def from_dictionary(cls, the_dict):
- filtered_dict = {key: value for key, value in the_dict.items() if key
in cls._properties()}
+ filtered_dict = {
+ key: value
+ for key, value in the_dict.items()
+ if key in cls._properties()
+ }
return cls(**filtered_dict)
def __repr__(self):
@@ -51,8 +57,10 @@ class _SubmarineObjectPrinter:
def to_string(self, obj):
if isinstance(obj, _SubmarineObject):
- return "<%s: %s>" % (get_classname(obj),
self._entity_to_string(obj))
+ return "<%s: %s>" % (get_classname(obj),
+ self._entity_to_string(obj))
return self.printer.pformat(obj)
def _entity_to_string(self, entity):
- return ", ".join(["%s=%s" % (key, self.to_string(value)) for key,
value in entity])
+ return ", ".join(
+ ["%s=%s" % (key, self.to_string(value)) for key, value in entity])
diff --git a/submarine-sdk/pysubmarine/submarine/exceptions.py
b/submarine-sdk/pysubmarine/submarine/exceptions.py
index 3d760fd..304395e 100644
--- a/submarine-sdk/pysubmarine/submarine/exceptions.py
+++ b/submarine-sdk/pysubmarine/submarine/exceptions.py
@@ -18,6 +18,7 @@ class SubmarineException(Exception):
"""
Generic exception thrown to surface failure information about
external-facing operations.
"""
+
def __init__(self, message):
"""
:param message: The message describing the error that occured.
@@ -28,9 +29,10 @@ class SubmarineException(Exception):
class RestException(SubmarineException):
"""Exception thrown on non 200-level responses from the REST API"""
+
def __init__(self, json):
error_code = json.get('error_code')
- message = "%s: %s" % (error_code,
- json['message'] if 'message' in json else
"Response: " + str(json))
+ message = "%s: %s" % (error_code, json['message'] if 'message' in json
+ else "Response: " + str(json))
super(RestException, self).__init__(message)
self.json = json
diff --git a/submarine-sdk/pysubmarine/submarine/job/__init__.py
b/submarine-sdk/pysubmarine/submarine/job/__init__.py
index 59a892d..ae4d323 100644
--- a/submarine-sdk/pysubmarine/submarine/job/__init__.py
+++ b/submarine-sdk/pysubmarine/submarine/job/__init__.py
@@ -16,7 +16,6 @@
# coding: utf-8
# flake8: noqa
-
"""
Submarine Experiment API
@@ -28,6 +27,7 @@
"""
from __future__ import absolute_import
+
# import apis into sdk package
from submarine.job.api.jobs_api import JobsApi
# import ApiClient
diff --git a/submarine-sdk/pysubmarine/submarine/job/api/__init__.py
b/submarine-sdk/pysubmarine/submarine/job/api/__init__.py
index 90aad48..952d741 100644
--- a/submarine-sdk/pysubmarine/submarine/job/api/__init__.py
+++ b/submarine-sdk/pysubmarine/submarine/job/api/__init__.py
@@ -15,7 +15,7 @@
from __future__ import absolute_import
-# flake8: noqa
-
# import apis into api package
from submarine.job.api.jobs_api import JobsApi
+
+# flake8: noqa
diff --git a/submarine-sdk/pysubmarine/submarine/job/api/jobs_api.py
b/submarine-sdk/pysubmarine/submarine/job/api/jobs_api.py
index d80f5d4..05195b9 100644
--- a/submarine-sdk/pysubmarine/submarine/job/api/jobs_api.py
+++ b/submarine-sdk/pysubmarine/submarine/job/api/jobs_api.py
@@ -14,7 +14,6 @@
# limitations under the License.
# coding: utf-8
-
"""
Submarine Experiment API
@@ -92,10 +91,8 @@ class JobsApi(object):
params = locals()
for key, val in six.iteritems(params['kwargs']):
if key not in all_params:
- raise TypeError(
- "Got an unexpected keyword argument '%s'"
- " to method create_job" % key
- )
+ raise TypeError("Got an unexpected keyword argument '%s'"
+ " to method create_job" % key)
params[key] = val
del params['kwargs']
@@ -118,14 +115,16 @@ class JobsApi(object):
['application/json; charset=utf-8']) # noqa: E501
# HTTP header `Content-Type`
- header_params['Content-Type'] =
self.api_client.select_header_content_type( # noqa: E501
- ['application/yaml', 'application/json']) # noqa: E501
+ header_params[
+ 'Content-Type'] = self.api_client.select_header_content_type( #
noqa: E501
+ ['application/yaml', 'application/json']) # noqa: E501
# Authentication setting
auth_settings = [] # noqa: E501
return self.api_client.call_api(
- '/v1/jobs', 'POST',
+ '/v1/jobs',
+ 'POST',
path_params,
query_params,
header_params,
@@ -185,16 +184,15 @@ class JobsApi(object):
params = locals()
for key, val in six.iteritems(params['kwargs']):
if key not in all_params:
- raise TypeError(
- "Got an unexpected keyword argument '%s'"
- " to method delete_job" % key
- )
+ raise TypeError("Got an unexpected keyword argument '%s'"
+ " to method delete_job" % key)
params[key] = val
del params['kwargs']
# verify the required parameter 'id' is set
- if ('id' not in params or
- params['id'] is None):
- raise ValueError("Missing the required parameter `id` when calling
`delete_job`") # noqa: E501
+ if ('id' not in params or params['id'] is None):
+ raise ValueError(
+ "Missing the required parameter `id` when calling `delete_job`"
+ ) # noqa: E501
collection_formats = {}
@@ -218,7 +216,8 @@ class JobsApi(object):
auth_settings = [] # noqa: E501
return self.api_client.call_api(
- '/v1/jobs/{id}', 'DELETE',
+ '/v1/jobs/{id}',
+ 'DELETE',
path_params,
query_params,
header_params,
@@ -278,16 +277,15 @@ class JobsApi(object):
params = locals()
for key, val in six.iteritems(params['kwargs']):
if key not in all_params:
- raise TypeError(
- "Got an unexpected keyword argument '%s'"
- " to method get_job" % key
- )
+ raise TypeError("Got an unexpected keyword argument '%s'"
+ " to method get_job" % key)
params[key] = val
del params['kwargs']
# verify the required parameter 'id' is set
- if ('id' not in params or
- params['id'] is None):
- raise ValueError("Missing the required parameter `id` when calling
`get_job`") # noqa: E501
+ if ('id' not in params or params['id'] is None):
+ raise ValueError(
+ "Missing the required parameter `id` when calling `get_job`"
+ ) # noqa: E501
collection_formats = {}
@@ -311,7 +309,8 @@ class JobsApi(object):
auth_settings = [] # noqa: E501
return self.api_client.call_api(
- '/v1/jobs/{id}', 'GET',
+ '/v1/jobs/{id}',
+ 'GET',
path_params,
query_params,
header_params,
@@ -371,16 +370,15 @@ class JobsApi(object):
params = locals()
for key, val in six.iteritems(params['kwargs']):
if key not in all_params:
- raise TypeError(
- "Got an unexpected keyword argument '%s'"
- " to method get_log" % key
- )
+ raise TypeError("Got an unexpected keyword argument '%s'"
+ " to method get_log" % key)
params[key] = val
del params['kwargs']
# verify the required parameter 'id' is set
- if ('id' not in params or
- params['id'] is None):
- raise ValueError("Missing the required parameter `id` when calling
`get_log`") # noqa: E501
+ if ('id' not in params or params['id'] is None):
+ raise ValueError(
+ "Missing the required parameter `id` when calling `get_log`"
+ ) # noqa: E501
collection_formats = {}
@@ -404,7 +402,8 @@ class JobsApi(object):
auth_settings = [] # noqa: E501
return self.api_client.call_api(
- '/v1/jobs/logs/{id}', 'GET',
+ '/v1/jobs/logs/{id}',
+ 'GET',
path_params,
query_params,
header_params,
@@ -464,10 +463,8 @@ class JobsApi(object):
params = locals()
for key, val in six.iteritems(params['kwargs']):
if key not in all_params:
- raise TypeError(
- "Got an unexpected keyword argument '%s'"
- " to method list_job" % key
- )
+ raise TypeError("Got an unexpected keyword argument '%s'"
+ " to method list_job" % key)
params[key] = val
del params['kwargs']
@@ -493,7 +490,8 @@ class JobsApi(object):
auth_settings = [] # noqa: E501
return self.api_client.call_api(
- '/v1/jobs', 'GET',
+ '/v1/jobs',
+ 'GET',
path_params,
query_params,
header_params,
@@ -553,10 +551,8 @@ class JobsApi(object):
params = locals()
for key, val in six.iteritems(params['kwargs']):
if key not in all_params:
- raise TypeError(
- "Got an unexpected keyword argument '%s'"
- " to method list_log" % key
- )
+ raise TypeError("Got an unexpected keyword argument '%s'"
+ " to method list_log" % key)
params[key] = val
del params['kwargs']
@@ -582,7 +578,8 @@ class JobsApi(object):
auth_settings = [] # noqa: E501
return self.api_client.call_api(
- '/v1/jobs/logs', 'GET',
+ '/v1/jobs/logs',
+ 'GET',
path_params,
query_params,
header_params,
@@ -644,16 +641,15 @@ class JobsApi(object):
params = locals()
for key, val in six.iteritems(params['kwargs']):
if key not in all_params:
- raise TypeError(
- "Got an unexpected keyword argument '%s'"
- " to method patch_job" % key
- )
+ raise TypeError("Got an unexpected keyword argument '%s'"
+ " to method patch_job" % key)
params[key] = val
del params['kwargs']
# verify the required parameter 'id' is set
- if ('id' not in params or
- params['id'] is None):
- raise ValueError("Missing the required parameter `id` when calling
`patch_job`") # noqa: E501
+ if ('id' not in params or params['id'] is None):
+ raise ValueError(
+ "Missing the required parameter `id` when calling `patch_job`"
+ ) # noqa: E501
collection_formats = {}
@@ -676,14 +672,16 @@ class JobsApi(object):
['application/json; charset=utf-8']) # noqa: E501
# HTTP header `Content-Type`
- header_params['Content-Type'] =
self.api_client.select_header_content_type( # noqa: E501
- ['application/yaml', 'application/json']) # noqa: E501
+ header_params[
+ 'Content-Type'] = self.api_client.select_header_content_type( #
noqa: E501
+ ['application/yaml', 'application/json']) # noqa: E501
# Authentication setting
auth_settings = [] # noqa: E501
return self.api_client.call_api(
- '/v1/jobs/{id}', 'PATCH',
+ '/v1/jobs/{id}',
+ 'PATCH',
path_params,
query_params,
header_params,
@@ -743,10 +741,8 @@ class JobsApi(object):
params = locals()
for key, val in six.iteritems(params['kwargs']):
if key not in all_params:
- raise TypeError(
- "Got an unexpected keyword argument '%s'"
- " to method ping" % key
- )
+ raise TypeError("Got an unexpected keyword argument '%s'"
+ " to method ping" % key)
params[key] = val
del params['kwargs']
@@ -770,7 +766,8 @@ class JobsApi(object):
auth_settings = [] # noqa: E501
return self.api_client.call_api(
- '/v1/jobs/ping', 'GET',
+ '/v1/jobs/ping',
+ 'GET',
path_params,
query_params,
header_params,
diff --git a/submarine-sdk/pysubmarine/submarine/job/api_client.py
b/submarine-sdk/pysubmarine/submarine/job/api_client.py
index 59b6743..3267628 100644
--- a/submarine-sdk/pysubmarine/submarine/job/api_client.py
+++ b/submarine-sdk/pysubmarine/submarine/job/api_client.py
@@ -28,18 +28,18 @@ from __future__ import absolute_import
import datetime
import json
import mimetypes
-from multiprocessing.pool import ThreadPool
import os
import re
import tempfile
+from multiprocessing.pool import ThreadPool
# python 2 and python 3 compatibility library
import six
from six.moves.urllib.parse import quote
-from submarine.job.configuration import Configuration
import submarine.job.models
from submarine.job import rest
+from submarine.job.configuration import Configuration
class ApiClient(object):
@@ -74,7 +74,10 @@ class ApiClient(object):
'object': object,
}
- def __init__(self, configuration=None, header_name=None, header_value=None,
+ def __init__(self,
+ configuration=None,
+ header_name=None,
+ header_value=None,
cookie=None):
if configuration is None:
configuration = Configuration()
@@ -105,12 +108,21 @@ class ApiClient(object):
def set_default_header(self, header_name, header_value):
self.default_headers[header_name] = header_value
- def __call_api(
- self, resource_path, method, path_params=None,
- query_params=None, header_params=None, body=None, post_params=None,
- files=None, response_type=None, auth_settings=None,
- _return_http_data_only=None, collection_formats=None,
- _preload_content=True, _request_timeout=None):
+ def __call_api(self,
+ resource_path,
+ method,
+ path_params=None,
+ query_params=None,
+ header_params=None,
+ body=None,
+ post_params=None,
+ files=None,
+ response_type=None,
+ auth_settings=None,
+ _return_http_data_only=None,
+ collection_formats=None,
+ _preload_content=True,
+ _request_timeout=None):
config = self.configuration
@@ -121,8 +133,8 @@ class ApiClient(object):
header_params['Cookie'] = self.cookie
if header_params:
header_params = self.sanitize_for_serialization(header_params)
- header_params = dict(self.parameters_to_tuples(header_params,
- collection_formats))
+ header_params = dict(
+ self.parameters_to_tuples(header_params, collection_formats))
# path parameters
if path_params:
@@ -133,8 +145,7 @@ class ApiClient(object):
# specified safe chars, encode everything
resource_path = resource_path.replace(
'{%s}' % k,
- quote(str(v), safe=config.safe_chars_for_path_param)
- )
+ quote(str(v), safe=config.safe_chars_for_path_param))
# query parameters
if query_params:
@@ -160,11 +171,14 @@ class ApiClient(object):
url = self.configuration.host + resource_path
# perform request and return response
- response_data = self.request(
- method, url, query_params=query_params, headers=header_params,
- post_params=post_params, body=body,
- _preload_content=_preload_content,
- _request_timeout=_request_timeout)
+ response_data = self.request(method,
+ url,
+ query_params=query_params,
+ headers=header_params,
+ post_params=post_params,
+ body=body,
+ _preload_content=_preload_content,
+ _request_timeout=_request_timeout)
self.last_response = response_data
@@ -201,11 +215,10 @@ class ApiClient(object):
elif isinstance(obj, self.PRIMITIVE_TYPES):
return obj
elif isinstance(obj, list):
- return [self.sanitize_for_serialization(sub_obj)
- for sub_obj in obj]
+ return [self.sanitize_for_serialization(sub_obj) for sub_obj in
obj]
elif isinstance(obj, tuple):
- return tuple(self.sanitize_for_serialization(sub_obj)
- for sub_obj in obj)
+ return tuple(
+ self.sanitize_for_serialization(sub_obj) for sub_obj in obj)
elif isinstance(obj, (datetime.datetime, datetime.date)):
return obj.isoformat()
@@ -217,12 +230,16 @@ class ApiClient(object):
# and attributes which value is not None.
# Convert attribute name to json key in
# model definition for request.
- obj_dict = {obj.attribute_map[attr]: getattr(obj, attr)
- for attr, _ in six.iteritems(obj.swagger_types)
- if getattr(obj, attr) is not None}
+ obj_dict = {
+ obj.attribute_map[attr]: getattr(obj, attr)
+ for attr, _ in six.iteritems(obj.swagger_types)
+ if getattr(obj, attr) is not None
+ }
- return {key: self.sanitize_for_serialization(val)
- for key, val in six.iteritems(obj_dict)}
+ return {
+ key: self.sanitize_for_serialization(val)
+ for key, val in six.iteritems(obj_dict)
+ }
def deserialize(self, response, response_type):
"""Deserializes response into an object.
@@ -260,13 +277,16 @@ class ApiClient(object):
if type(klass) == str:
if klass.startswith('list['):
sub_kls = re.match(r'list\[(.*)\]', klass).group(1)
- return [self.__deserialize(sub_data, sub_kls)
- for sub_data in data]
+ return [
+ self.__deserialize(sub_data, sub_kls) for sub_data in data
+ ]
if klass.startswith('dict('):
sub_kls = re.match(r'dict\(([^,]*), (.*)\)', klass).group(2)
- return {k: self.__deserialize(v, sub_kls)
- for k, v in six.iteritems(data)}
+ return {
+ k: self.__deserialize(v, sub_kls)
+ for k, v in six.iteritems(data)
+ }
# convert str to class
if klass in self.NATIVE_TYPES_MAPPING:
@@ -285,12 +305,22 @@ class ApiClient(object):
else:
return self.__deserialize_model(data, klass)
- def call_api(self, resource_path, method,
- path_params=None, query_params=None, header_params=None,
- body=None, post_params=None, files=None,
- response_type=None, auth_settings=None, async_req=None,
- _return_http_data_only=None, collection_formats=None,
- _preload_content=True, _request_timeout=None):
+ def call_api(self,
+ resource_path,
+ method,
+ path_params=None,
+ query_params=None,
+ header_params=None,
+ body=None,
+ post_params=None,
+ files=None,
+ response_type=None,
+ auth_settings=None,
+ async_req=None,
+ _return_http_data_only=None,
+ collection_formats=None,
+ _preload_content=True,
+ _request_timeout=None):
"""Makes the HTTP request (synchronous) and returns deserialized data.
To make an async request, set the async_req parameter.
@@ -328,25 +358,29 @@ class ApiClient(object):
then the method will return the response directly.
"""
if not async_req:
- return self.__call_api(resource_path, method,
- path_params, query_params, header_params,
- body, post_params, files,
- response_type, auth_settings,
- _return_http_data_only, collection_formats,
- _preload_content, _request_timeout)
+ return self.__call_api(resource_path, method, path_params,
+ query_params, header_params, body,
+ post_params, files, response_type,
+ auth_settings, _return_http_data_only,
+ collection_formats, _preload_content,
+ _request_timeout)
else:
- thread = self.pool.apply_async(self.__call_api, (resource_path,
- method, path_params, query_params,
- header_params, body,
- post_params, files,
- response_type, auth_settings,
- _return_http_data_only,
- collection_formats,
- _preload_content, _request_timeout))
+ thread = self.pool.apply_async(
+ self.__call_api,
+ (resource_path, method, path_params, query_params,
+ header_params, body, post_params, files, response_type,
+ auth_settings, _return_http_data_only, collection_formats,
+ _preload_content, _request_timeout))
return thread
- def request(self, method, url, query_params=None, headers=None,
- post_params=None, body=None, _preload_content=True,
+ def request(self,
+ method,
+ url,
+ query_params=None,
+ headers=None,
+ post_params=None,
+ body=None,
+ _preload_content=True,
_request_timeout=None):
"""Makes the HTTP request using RESTClient."""
if method == "GET":
@@ -401,10 +435,8 @@ class ApiClient(object):
_request_timeout=_request_timeout,
body=body)
else:
- raise ValueError(
- "http method must be `GET`, `HEAD`, `OPTIONS`,"
- " `POST`, `PATCH`, `PUT` or `DELETE`."
- )
+ raise ValueError("http method must be `GET`, `HEAD`, `OPTIONS`,"
+ " `POST`, `PATCH`, `PUT` or `DELETE`.")
def parameters_to_tuples(self, params, collection_formats):
"""Get parameters as list of tuples, formatting collections.
@@ -416,7 +448,8 @@ class ApiClient(object):
new_params = []
if collection_formats is None:
collection_formats = {}
- for k, v in six.iteritems(params) if isinstance(params, dict) else
params: # noqa: E501
+ for k, v in six.iteritems(params) if isinstance(
+ params, dict) else params: # noqa: E501
if k in collection_formats:
collection_format = collection_formats[k]
if collection_format == 'multi':
@@ -517,8 +550,7 @@ class ApiClient(object):
querys.append((auth_setting['key'], auth_setting['value']))
else:
raise ValueError(
- 'Authentication token must be in `query` or `header`'
- )
+ 'Authentication token must be in `query` or `header`')
def __deserialize_file(self, response):
"""Deserializes body to file
@@ -580,8 +612,7 @@ class ApiClient(object):
except ValueError:
raise rest.ApiException(
status=0,
- reason="Failed to parse `{0}` as date object".format(string)
- )
+ reason="Failed to parse `{0}` as date object".format(string))
def __deserialize_datatime(self, string):
"""Deserializes string to datetime.
@@ -600,10 +631,7 @@ class ApiClient(object):
raise rest.ApiException(
status=0,
reason=(
- "Failed to parse `{0}` as datetime object"
- .format(string)
- )
- )
+ "Failed to parse `{0}` as datetime object".format(string)))
def __hasattr(self, object, name):
return name in object.__class__.__dict__
@@ -616,22 +644,21 @@ class ApiClient(object):
:return: model object.
"""
- if not klass.swagger_types and not self.__hasattr(klass,
'get_real_child_model'):
+ if not klass.swagger_types and not self.__hasattr(
+ klass, 'get_real_child_model'):
return data
kwargs = {}
if klass.swagger_types is not None:
for attr, attr_type in six.iteritems(klass.swagger_types):
- if (data is not None and
- klass.attribute_map[attr] in data and
+ if (data is not None and klass.attribute_map[attr] in data and
isinstance(data, (list, dict))):
value = data[klass.attribute_map[attr]]
kwargs[attr] = self.__deserialize(value, attr_type)
instance = klass(**kwargs)
- if (isinstance(instance, dict) and
- klass.swagger_types is not None and
+ if (isinstance(instance, dict) and klass.swagger_types is not None and
isinstance(data, dict)):
for key, value in data.items():
if key not in klass.swagger_types:
diff --git a/submarine-sdk/pysubmarine/submarine/job/configuration.py
b/submarine-sdk/pysubmarine/submarine/job/configuration.py
index 105f775..f8be11a 100644
--- a/submarine-sdk/pysubmarine/submarine/job/configuration.py
+++ b/submarine-sdk/pysubmarine/submarine/job/configuration.py
@@ -14,7 +14,6 @@
# limitations under the License.
# coding: utf-8
-
"""
Submarine Experiment API
@@ -31,13 +30,14 @@ import copy
import logging
import multiprocessing
import sys
-import urllib3
import six
+import urllib3
from six.moves import http_client as httplib
class TypeWithDefault(type):
+
def __init__(cls, name, bases, dct):
super(TypeWithDefault, cls).__init__(name, bases, dct)
cls._default = None
@@ -234,17 +234,15 @@ class Configuration(six.with_metaclass(TypeWithDefault,
object)):
:return: The token for basic HTTP authentication.
"""
- return urllib3.util.make_headers(
- basic_auth=self.username + ':' + self.password
- ).get('authorization')
+ return urllib3.util.make_headers(basic_auth=self.username + ':' +
+ self.password).get('authorization')
def auth_settings(self):
"""Gets Auth Settings dict for api client.
:return: The Auth Settings information dict.
"""
- return {
- }
+ return {}
def to_debug_report(self):
"""Gets the essential information for debugging.
diff --git a/submarine-sdk/pysubmarine/submarine/job/models/job_library_spec.py
b/submarine-sdk/pysubmarine/submarine/job/models/job_library_spec.py
index 9a0dc81..8d017f3 100644
--- a/submarine-sdk/pysubmarine/submarine/job/models/job_library_spec.py
+++ b/submarine-sdk/pysubmarine/submarine/job/models/job_library_spec.py
@@ -14,7 +14,6 @@
# limitations under the License.
# coding: utf-8
-
"""
Submarine Experiment API
@@ -59,7 +58,12 @@ class JobLibrarySpec(object):
'env_vars': 'envVars'
}
- def __init__(self, name=None, version=None, image=None, cmd=None,
env_vars=None): # noqa: E501
+ def __init__(self,
+ name=None,
+ version=None,
+ image=None,
+ cmd=None,
+ env_vars=None): # noqa: E501
"""JobLibrarySpec - a model defined in Swagger""" # noqa: E501
self._name = None
self._version = None
@@ -190,18 +194,17 @@ class JobLibrarySpec(object):
for attr, _ in six.iteritems(self.swagger_types):
value = getattr(self, attr)
if isinstance(value, list):
- result[attr] = list(map(
- lambda x: x.to_dict() if hasattr(x, "to_dict") else x,
- value
- ))
+ result[attr] = list(
+ map(lambda x: x.to_dict()
+ if hasattr(x, "to_dict") else x, value))
elif hasattr(value, "to_dict"):
result[attr] = value.to_dict()
elif isinstance(value, dict):
- result[attr] = dict(map(
- lambda item: (item[0], item[1].to_dict())
- if hasattr(item[1], "to_dict") else item,
- value.items()
- ))
+ result[attr] = dict(
+ map(
+ lambda item: (item[0], item[1].to_dict())
+ if hasattr(item[1], "to_dict") else item,
+ value.items()))
else:
result[attr] = value
if issubclass(JobLibrarySpec, dict):
diff --git a/submarine-sdk/pysubmarine/submarine/job/models/job_spec.py
b/submarine-sdk/pysubmarine/submarine/job/models/job_spec.py
index 8da643c..f038d23 100644
--- a/submarine-sdk/pysubmarine/submarine/job/models/job_spec.py
+++ b/submarine-sdk/pysubmarine/submarine/job/models/job_spec.py
@@ -14,7 +14,6 @@
# limitations under the License.
# coding: utf-8
-
"""
Submarine Experiment API
@@ -59,7 +58,12 @@ class JobSpec(object):
'projects': 'projects'
}
- def __init__(self, name=None, namespace=None, library_spec=None,
task_specs=None, projects=None): # noqa: E501
+ def __init__(self,
+ name=None,
+ namespace=None,
+ library_spec=None,
+ task_specs=None,
+ projects=None): # noqa: E501
"""JobSpec - a model defined in Swagger""" # noqa: E501
self._name = None
self._namespace = None
@@ -190,18 +194,17 @@ class JobSpec(object):
for attr, _ in six.iteritems(self.swagger_types):
value = getattr(self, attr)
if isinstance(value, list):
- result[attr] = list(map(
- lambda x: x.to_dict() if hasattr(x, "to_dict") else x,
- value
- ))
+ result[attr] = list(
+ map(lambda x: x.to_dict()
+ if hasattr(x, "to_dict") else x, value))
elif hasattr(value, "to_dict"):
result[attr] = value.to_dict()
elif isinstance(value, dict):
- result[attr] = dict(map(
- lambda item: (item[0], item[1].to_dict())
- if hasattr(item[1], "to_dict") else item,
- value.items()
- ))
+ result[attr] = dict(
+ map(
+ lambda item: (item[0], item[1].to_dict())
+ if hasattr(item[1], "to_dict") else item,
+ value.items()))
else:
result[attr] = value
if issubclass(JobSpec, dict):
diff --git a/submarine-sdk/pysubmarine/submarine/job/models/job_task_spec.py
b/submarine-sdk/pysubmarine/submarine/job/models/job_task_spec.py
index 901b3d4..2d5ac5b 100644
--- a/submarine-sdk/pysubmarine/submarine/job/models/job_task_spec.py
+++ b/submarine-sdk/pysubmarine/submarine/job/models/job_task_spec.py
@@ -14,7 +14,6 @@
# limitations under the License.
# coding: utf-8
-
"""
Submarine Experiment API
@@ -67,7 +66,16 @@ class JobTaskSpec(object):
'memory': 'memory'
}
- def __init__(self, name=None, image=None, cmd=None, env_vars=None,
resources=None, replicas=None, cpu=None, gpu=None, memory=None): # noqa: E501
+ def __init__(self,
+ name=None,
+ image=None,
+ cmd=None,
+ env_vars=None,
+ resources=None,
+ replicas=None,
+ cpu=None,
+ gpu=None,
+ memory=None): # noqa: E501
"""JobTaskSpec - a model defined in Swagger""" # noqa: E501
self._name = None
self._image = None
@@ -294,18 +302,17 @@ class JobTaskSpec(object):
for attr, _ in six.iteritems(self.swagger_types):
value = getattr(self, attr)
if isinstance(value, list):
- result[attr] = list(map(
- lambda x: x.to_dict() if hasattr(x, "to_dict") else x,
- value
- ))
+ result[attr] = list(
+ map(lambda x: x.to_dict()
+ if hasattr(x, "to_dict") else x, value))
elif hasattr(value, "to_dict"):
result[attr] = value.to_dict()
elif isinstance(value, dict):
- result[attr] = dict(map(
- lambda item: (item[0], item[1].to_dict())
- if hasattr(item[1], "to_dict") else item,
- value.items()
- ))
+ result[attr] = dict(
+ map(
+ lambda item: (item[0], item[1].to_dict())
+ if hasattr(item[1], "to_dict") else item,
+ value.items()))
else:
result[attr] = value
if issubclass(JobTaskSpec, dict):
diff --git a/submarine-sdk/pysubmarine/submarine/job/models/json_response.py
b/submarine-sdk/pysubmarine/submarine/job/models/json_response.py
index cf86c70..9a3c94d 100644
--- a/submarine-sdk/pysubmarine/submarine/job/models/json_response.py
+++ b/submarine-sdk/pysubmarine/submarine/job/models/json_response.py
@@ -14,7 +14,6 @@
# limitations under the License.
# coding: utf-8
-
"""
Submarine Experiment API
@@ -57,7 +56,11 @@ class JsonResponse(object):
'attributes': 'attributes'
}
- def __init__(self, code=None, success=None, result=None, attributes=None):
# noqa: E501
+ def __init__(self,
+ code=None,
+ success=None,
+ result=None,
+ attributes=None): # noqa: E501
"""JsonResponse - a model defined in Swagger""" # noqa: E501
self._code = None
self._success = None
@@ -164,18 +167,17 @@ class JsonResponse(object):
for attr, _ in six.iteritems(self.swagger_types):
value = getattr(self, attr)
if isinstance(value, list):
- result[attr] = list(map(
- lambda x: x.to_dict() if hasattr(x, "to_dict") else x,
- value
- ))
+ result[attr] = list(
+ map(lambda x: x.to_dict()
+ if hasattr(x, "to_dict") else x, value))
elif hasattr(value, "to_dict"):
result[attr] = value.to_dict()
elif isinstance(value, dict):
- result[attr] = dict(map(
- lambda item: (item[0], item[1].to_dict())
- if hasattr(item[1], "to_dict") else item,
- value.items()
- ))
+ result[attr] = dict(
+ map(
+ lambda item: (item[0], item[1].to_dict())
+ if hasattr(item[1], "to_dict") else item,
+ value.items()))
else:
result[attr] = value
if issubclass(JsonResponse, dict):
diff --git a/submarine-sdk/pysubmarine/submarine/job/rest.py
b/submarine-sdk/pysubmarine/submarine/job/rest.py
index 0d9c379..94f0288 100644
--- a/submarine-sdk/pysubmarine/submarine/job/rest.py
+++ b/submarine-sdk/pysubmarine/submarine/job/rest.py
@@ -14,7 +14,6 @@
# limitations under the License.
# coding: utf-8
-
"""
Submarine Experiment API
@@ -43,7 +42,6 @@ try:
except ImportError:
raise ImportError('Swagger python client requires urllib3.')
-
logger = logging.getLogger(__name__)
@@ -88,7 +86,8 @@ class RESTClientObject(object):
addition_pool_args = {}
if configuration.assert_hostname is not None:
- addition_pool_args['assert_hostname'] =
configuration.assert_hostname # noqa: E501
+ addition_pool_args[
+ 'assert_hostname'] = configuration.assert_hostname # noqa:
E501
if maxsize is None:
if configuration.connection_pool_maxsize is not None:
@@ -106,8 +105,7 @@ class RESTClientObject(object):
cert_file=configuration.cert_file,
key_file=configuration.key_file,
proxy_url=configuration.proxy,
- **addition_pool_args
- )
+ **addition_pool_args)
else:
self.pool_manager = urllib3.PoolManager(
num_pools=pools_size,
@@ -116,11 +114,16 @@ class RESTClientObject(object):
ca_certs=ca_certs,
cert_file=configuration.cert_file,
key_file=configuration.key_file,
- **addition_pool_args
- )
-
- def request(self, method, url, query_params=None, headers=None,
- body=None, post_params=None, _preload_content=True,
+ **addition_pool_args)
+
+ def request(self,
+ method,
+ url,
+ query_params=None,
+ headers=None,
+ body=None,
+ post_params=None,
+ _preload_content=True,
_request_timeout=None):
"""Perform requests.
@@ -141,25 +144,26 @@ class RESTClientObject(object):
(connection, read) timeouts.
"""
method = method.upper()
- assert method in ['GET', 'HEAD', 'DELETE', 'POST', 'PUT',
- 'PATCH', 'OPTIONS']
+ assert method in [
+ 'GET', 'HEAD', 'DELETE', 'POST', 'PUT', 'PATCH', 'OPTIONS'
+ ]
if post_params and body:
raise ValueError(
- "body parameter cannot be used with post_params parameter."
- )
+ "body parameter cannot be used with post_params parameter.")
post_params = post_params or {}
headers = headers or {}
timeout = None
if _request_timeout:
- if isinstance(_request_timeout, (int, ) if six.PY3 else (int,
long)): # noqa: E501,F821
+ if isinstance(_request_timeout, (int,) if six.PY3 else
+ (int, long)): # noqa: E501,F821
timeout = urllib3.Timeout(total=_request_timeout)
elif (isinstance(_request_timeout, tuple) and
len(_request_timeout) == 2):
- timeout = urllib3.Timeout(
- connect=_request_timeout[0], read=_request_timeout[1])
+ timeout = urllib3.Timeout(connect=_request_timeout[0],
+ read=_request_timeout[1])
if 'Content-Type' not in headers:
headers['Content-Type'] = 'application/json'
@@ -174,14 +178,17 @@ class RESTClientObject(object):
if body is not None:
request_body = json.dumps(body)
r = self.pool_manager.request(
- method, url,
+ method,
+ url,
body=request_body,
preload_content=_preload_content,
timeout=timeout,
headers=headers)
- elif headers['Content-Type'] ==
'application/x-www-form-urlencoded': # noqa: E501
+ elif headers[
+ 'Content-Type'] ==
'application/x-www-form-urlencoded': # noqa: E501
r = self.pool_manager.request(
- method, url,
+ method,
+ url,
fields=post_params,
encode_multipart=False,
preload_content=_preload_content,
@@ -193,7 +200,8 @@ class RESTClientObject(object):
# overwritten.
del headers['Content-Type']
r = self.pool_manager.request(
- method, url,
+ method,
+ url,
fields=post_params,
encode_multipart=True,
preload_content=_preload_content,
@@ -205,7 +213,8 @@ class RESTClientObject(object):
elif isinstance(body, str):
request_body = body
r = self.pool_manager.request(
- method, url,
+ method,
+ url,
body=request_body,
preload_content=_preload_content,
timeout=timeout,
@@ -218,7 +227,8 @@ class RESTClientObject(object):
raise ApiException(status=0, reason=msg)
# For `GET`, `HEAD`
else:
- r = self.pool_manager.request(method, url,
+ r = self.pool_manager.request(method,
+ url,
fields=query_params,
preload_content=_preload_content,
timeout=timeout,
@@ -243,25 +253,42 @@ class RESTClientObject(object):
return r
- def GET(self, url, headers=None, query_params=None, _preload_content=True,
+ def GET(self,
+ url,
+ headers=None,
+ query_params=None,
+ _preload_content=True,
_request_timeout=None):
- return self.request("GET", url,
+ return self.request("GET",
+ url,
headers=headers,
_preload_content=_preload_content,
_request_timeout=_request_timeout,
query_params=query_params)
- def HEAD(self, url, headers=None, query_params=None, _preload_content=True,
+ def HEAD(self,
+ url,
+ headers=None,
+ query_params=None,
+ _preload_content=True,
_request_timeout=None):
- return self.request("HEAD", url,
+ return self.request("HEAD",
+ url,
headers=headers,
_preload_content=_preload_content,
_request_timeout=_request_timeout,
query_params=query_params)
- def OPTIONS(self, url, headers=None, query_params=None, post_params=None,
- body=None, _preload_content=True, _request_timeout=None):
- return self.request("OPTIONS", url,
+ def OPTIONS(self,
+ url,
+ headers=None,
+ query_params=None,
+ post_params=None,
+ body=None,
+ _preload_content=True,
+ _request_timeout=None):
+ return self.request("OPTIONS",
+ url,
headers=headers,
query_params=query_params,
post_params=post_params,
@@ -269,18 +296,31 @@ class RESTClientObject(object):
_request_timeout=_request_timeout,
body=body)
- def DELETE(self, url, headers=None, query_params=None, body=None,
- _preload_content=True, _request_timeout=None):
- return self.request("DELETE", url,
+ def DELETE(self,
+ url,
+ headers=None,
+ query_params=None,
+ body=None,
+ _preload_content=True,
+ _request_timeout=None):
+ return self.request("DELETE",
+ url,
headers=headers,
query_params=query_params,
_preload_content=_preload_content,
_request_timeout=_request_timeout,
body=body)
- def POST(self, url, headers=None, query_params=None, post_params=None,
- body=None, _preload_content=True, _request_timeout=None):
- return self.request("POST", url,
+ def POST(self,
+ url,
+ headers=None,
+ query_params=None,
+ post_params=None,
+ body=None,
+ _preload_content=True,
+ _request_timeout=None):
+ return self.request("POST",
+ url,
headers=headers,
query_params=query_params,
post_params=post_params,
@@ -288,9 +328,16 @@ class RESTClientObject(object):
_request_timeout=_request_timeout,
body=body)
- def PUT(self, url, headers=None, query_params=None, post_params=None,
- body=None, _preload_content=True, _request_timeout=None):
- return self.request("PUT", url,
+ def PUT(self,
+ url,
+ headers=None,
+ query_params=None,
+ post_params=None,
+ body=None,
+ _preload_content=True,
+ _request_timeout=None):
+ return self.request("PUT",
+ url,
headers=headers,
query_params=query_params,
post_params=post_params,
@@ -298,9 +345,16 @@ class RESTClientObject(object):
_request_timeout=_request_timeout,
body=body)
- def PATCH(self, url, headers=None, query_params=None, post_params=None,
- body=None, _preload_content=True, _request_timeout=None):
- return self.request("PATCH", url,
+ def PATCH(self,
+ url,
+ headers=None,
+ query_params=None,
+ post_params=None,
+ body=None,
+ _preload_content=True,
+ _request_timeout=None):
+ return self.request("PATCH",
+ url,
headers=headers,
query_params=query_params,
post_params=post_params,
@@ -328,8 +382,7 @@ class ApiException(Exception):
error_message = "({0})\n"\
"Reason: {1}\n".format(self.status, self.reason)
if self.headers:
- error_message += "HTTP response headers: {0}\n".format(
- self.headers)
+ error_message += "HTTP response headers:
{0}\n".format(self.headers)
if self.body:
error_message += "HTTP response body: {0}\n".format(self.body)
diff --git
a/submarine-sdk/pysubmarine/submarine/ml/pytorch/input/libsvm_dataset.py
b/submarine-sdk/pysubmarine/submarine/ml/pytorch/input/libsvm_dataset.py
index 07980bd..a0c8a4e 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/input/libsvm_dataset.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/input/libsvm_dataset.py
@@ -13,16 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from submarine.utils.fileio import read_file
-
+import pandas as pd
import torch
-from torch.utils.data import Dataset
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
-import pandas as pd
+
+from submarine.utils.fileio import read_file
class LIBSVMDataset(Dataset):
+
def __init__(self, path):
self.data, self.label = self.preprocess_data(read_file(path))
@@ -33,6 +33,7 @@ class LIBSVMDataset(Dataset):
return len(self.data)
def preprocess_data(self, stream):
+
def _convert_line(line):
feat_ids = []
feat_vals = []
@@ -52,21 +53,19 @@ class LIBSVMDataset(Dataset):
def collate_fn(self, batch):
data, label = tuple(zip(*batch))
_, feat_val = tuple(zip(*data))
- return (
- torch.stack(feat_val, dim=0).type(torch.long),
- torch.as_tensor(label, dtype=torch.float32).unsqueeze(dim=-1)
- )
+ return (torch.stack(feat_val, dim=0).type(torch.long),
+ torch.as_tensor(label, dtype=torch.float32).unsqueeze(dim=-1))
def libsvm_input_fn(filepath, batch_size=256, num_threads=1, **kwargs):
+
def _input_fn():
dataset = LIBSVMDataset(filepath)
sampler = DistributedSampler(dataset)
- return DataLoader(
- dataset=dataset,
- batch_size=batch_size,
- sampler=sampler,
- num_workers=num_threads,
- collate_fn=dataset.collate_fn
- )
+ return DataLoader(dataset=dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ num_workers=num_threads,
+ collate_fn=dataset.collate_fn)
+
return _input_fn
diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py
b/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py
index d0452e0..6fff591 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/layers/core.py
@@ -13,24 +13,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from itertools import accumulate
+
import torch
from torch import nn
-from itertools import accumulate
-
class FieldLinear(nn.Module):
+
def __init__(self, field_dims, out_features):
"""
:param field_dims: List of dimensions of each field.
:param out_features: The number of output features.
"""
super().__init__()
- self.weight = nn.Embedding(num_embeddings=sum(
- field_dims), embedding_dim=out_features)
+ self.weight = nn.Embedding(num_embeddings=sum(field_dims),
+ embedding_dim=out_features)
self.bias = nn.Parameter(torch.zeros((out_features,)))
- self.register_buffer('offset',
- torch.as_tensor([0,
*accumulate(field_dims)][:-1], dtype=torch.long))
+ self.register_buffer(
+ 'offset',
+ torch.as_tensor([0, *accumulate(field_dims)][:-1],
+ dtype=torch.long))
def forward(self, x):
"""
@@ -40,47 +43,51 @@ class FieldLinear(nn.Module):
class FieldEmbedding(nn.Module):
+
def __init__(self, field_dims, embedding_dim):
super().__init__()
- self.weight = nn.Embedding(num_embeddings=sum(
- field_dims), embedding_dim=embedding_dim)
- self.register_buffer('offset', torch.as_tensor(
- [0, *accumulate(field_dims)][:-1], dtype=torch.long))
+ self.weight = nn.Embedding(num_embeddings=sum(field_dims),
+ embedding_dim=embedding_dim)
+ self.register_buffer(
+ 'offset',
+ torch.as_tensor([0, *accumulate(field_dims)][:-1],
+ dtype=torch.long))
def forward(self, x):
"""
:param x: torch.LongTensor (batch_size, num_fields)
"""
- return self.weight(x + self.offset) # (batch_size, num_fields,
embedding_dim)
+ return self.weight(
+ x + self.offset) # (batch_size, num_fields, embedding_dim)
class PairwiseInteraction(nn.Module):
+
def forward(self, x):
"""
:param x: torch.Tensor (batch_size, num_fields, embedding_dim)
"""
- square_of_sum = torch.square(
- torch.sum(x, dim=1)) # (batch_size, embedding_dim)
+ square_of_sum = torch.square(torch.sum(
+ x, dim=1)) # (batch_size, embedding_dim)
# (batch_size, embedding_dim)
sum_of_square = torch.sum(torch.square(x), dim=1)
- return 0.5 * torch.sum(square_of_sum - sum_of_square, dim=1,
+ return 0.5 * torch.sum(square_of_sum - sum_of_square,
+ dim=1,
keepdim=True) # (batch_size, 1)
class DNN(nn.Module):
+
def __init__(self, in_features, out_features, hidden_units, dropout_rates):
super().__init__()
- *layers, out_layer = list(zip([in_features, *
- hidden_units], [*hidden_units,
out_features]))
+ *layers, out_layer = list(
+ zip([in_features, *hidden_units], [*hidden_units, out_features]))
self.net = nn.Sequential(
- *(nn.Sequential(
- nn.Linear(in_features=i, out_features=o),
- nn.BatchNorm1d(num_features=o),
- nn.ReLU(),
- nn.Dropout(p=p)
- ) for (i, o), p in zip(layers, dropout_rates)),
- nn.Linear(*out_layer)
- )
+ *(nn.Sequential(nn.Linear(in_features=i, out_features=o),
+ nn.BatchNorm1d(num_features=o), nn.ReLU(),
+ nn.Dropout(p=p))
+ for (i, o), p in zip(layers, dropout_rates)),
+ nn.Linear(*out_layer))
def forward(self, x):
"""
diff --git
a/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/base_pytorch_model.py
b/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/base_pytorch_model.py
index 7e24672..b862a0e 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/base_pytorch_model.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/base_pytorch_model.py
@@ -13,24 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from abc import ABC
+import io
import logging
+import os
+from abc import ABC
-from submarine.ml.pytorch.metric import get_metric_fn
-from submarine.ml.pytorch.registries import input_fn_registry
-from submarine.utils.fileio import write_file
-from submarine.ml.abstract_model import AbstractModel
import torch
from torch import distributed
from torch.nn.parallel import DistributedDataParallel
-import os
-import io
-
-from submarine.ml.pytorch.optimizer import get_optimizer
+from submarine.ml.abstract_model import AbstractModel
from submarine.ml.pytorch.loss import get_loss_fn
+from submarine.ml.pytorch.metric import get_metric_fn
+from submarine.ml.pytorch.optimizer import get_optimizer
from submarine.ml.pytorch.parameters import default_parameters
-from submarine.utils.env import get_from_registry, get_from_json,
get_from_dicts
+from submarine.ml.pytorch.registries import input_fn_registry
+from submarine.utils.env import (get_from_dicts, get_from_json,
+ get_from_registry)
+from submarine.utils.fileio import write_file
from submarine.utils.pytorch_utils import get_device
logger = logging.getLogger(__name__)
@@ -38,6 +38,7 @@ logger = logging.getLogger(__name__)
# pylint: disable=W0221
class BasePyTorchModel(AbstractModel, ABC):
+
def __init__(self, params=None, json_path=None):
super().__init__()
self.params = get_from_dicts(params, default_parameters)
@@ -51,8 +52,7 @@ class BasePyTorchModel(AbstractModel, ABC):
self.model_fn(self.params).to(get_device(self.params)))
self.optimizer = get_optimizer(key=self.params['optimizer']['name'])(
params=self.model.parameters(),
- **self.params['optimizer']['kwargs']
- )
+ **self.params['optimizer']['kwargs'])
self.loss = get_loss_fn(key=self.params['loss']['name'])(
**self.params['loss']['kwargs'])
self.metric = get_metric_fn(key=self.params['output']['metric'])
@@ -62,8 +62,7 @@ class BasePyTorchModel(AbstractModel, ABC):
backend=os.environ.get('backend', distributed.Backend.GLOO),
init_method=os.environ.get('INIT_METHOD', 'tcp://127.0.0.1:23456'),
world_size=int(os.environ.get('WORLD', 1)),
- rank=int(os.environ.get('RANK', 0))
- )
+ rank=int(os.environ.get('RANK', 0)))
def __del__(self):
distributed.destroy_process_group()
@@ -81,8 +80,7 @@ class BasePyTorchModel(AbstractModel, ABC):
outputs = []
targets = []
- valid_loader = get_from_registry(
- self.input_type, input_fn_registry)(
+ valid_loader = get_from_registry(self.input_type, input_fn_registry)(
filepath=self.params['input']['valid_data'],
**self.params['training'])()
@@ -96,14 +94,12 @@ class BasePyTorchModel(AbstractModel, ABC):
return self.metric(
torch.cat(targets, dim=0).cpu().numpy(),
- torch.cat(outputs, dim=0).cpu().numpy()
- )
+ torch.cat(outputs, dim=0).cpu().numpy())
def predict(self):
outputs = []
- test_loader = get_from_registry(
- self.input_type, input_fn_registry)(
+ test_loader = get_from_registry(self.input_type, input_fn_registry)(
filepath=self.params['input']['test_data'],
**self.params['training'])()
@@ -123,8 +119,7 @@ class BasePyTorchModel(AbstractModel, ABC):
# The line "if eval_score > best_eval_score:"
# should be replaced by a indicator function.
best_eval_score = 0.0
- train_loader = get_from_registry(
- self.input_type, input_fn_registry)(
+ train_loader = get_from_registry(self.input_type, input_fn_registry)(
filepath=self.params['input']['train_data'],
**self.params['training'])()
@@ -144,25 +139,18 @@ class BasePyTorchModel(AbstractModel, ABC):
{
'model': self.model.module.state_dict(),
'optimizer': self.optimizer.state_dict()
- }, buffer
- )
- write_file(
- buffer,
- path=os.path.join(
- self.params['output']['save_model_dir'], 'ckpt.pkl')
- )
+ }, buffer)
+ write_file(buffer,
+ path=os.path.join(
+ self.params['output']['save_model_dir'],
'ckpt.pkl'))
def model_fn(self, params):
seed = params["training"]["seed"]
torch.manual_seed(seed)
def _sanity_check(self):
- assert 'input' in self.params, (
- 'Does not define any input parameters'
- )
+ assert 'input' in self.params, ('Does not define any input parameters')
assert 'type' in self.params['input'], (
- 'Does not define any input type'
- )
+ 'Does not define any input type')
assert 'output' in self.params, (
- 'Does not define any output parameters'
- )
+ 'Does not define any output parameters')
diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/deepfm.py
b/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/deepfm.py
index ee64e2d..6c955d7 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/deepfm.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/ctr/deepfm.py
@@ -13,44 +13,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from submarine.ml.pytorch.layers.core import FieldLinear
-from submarine.ml.pytorch.layers.core import FieldEmbedding
-from submarine.ml.pytorch.layers.core import PairwiseInteraction
-from submarine.ml.pytorch.layers.core import DNN
-
import torch
from torch import nn
+
+from submarine.ml.pytorch.layers.core import (DNN, FieldEmbedding, FieldLinear,
+ PairwiseInteraction)
from submarine.ml.pytorch.model.base_pytorch_model import BasePyTorchModel
class DeepFM(BasePyTorchModel):
+
def model_fn(self, params):
super().model_fn(params)
return _DeepFM(**self.params['model']['kwargs'])
class _DeepFM(nn.Module):
- def __init__(self, field_dims, embedding_dim, out_features,
- hidden_units, dropout_rates, **kwargs):
+
+ def __init__(self, field_dims, embedding_dim, out_features, hidden_units,
+ dropout_rates, **kwargs):
super().__init__()
- self.field_linear = FieldLinear(
- field_dims=field_dims, out_features=out_features)
- self.field_embedding = FieldEmbedding(
- field_dims=field_dims, embedding_dim=embedding_dim)
+ self.field_linear = FieldLinear(field_dims=field_dims,
+ out_features=out_features)
+ self.field_embedding = FieldEmbedding(field_dims=field_dims,
+ embedding_dim=embedding_dim)
self.pairwise_interaction = PairwiseInteraction()
- self.dnn = DNN(
- in_features=len(field_dims)*embedding_dim,
- out_features=out_features,
- hidden_units=hidden_units,
- dropout_rates=dropout_rates
- )
+ self.dnn = DNN(in_features=len(field_dims) * embedding_dim,
+ out_features=out_features,
+ hidden_units=hidden_units,
+ dropout_rates=dropout_rates)
def forward(self, x):
"""
:param x: torch.LongTensor (batch_size, num_fields)
"""
- emb = self.field_embedding(
- x) # (batch_size, num_fields, embedding_dim)
+ emb = self.field_embedding(x) # (batch_size, num_fields,
embedding_dim)
linear_logit = self.field_linear(x)
fm_logit = self.pairwise_interaction(emb)
deep_logit = self.dnn(torch.flatten(emb, start_dim=1))
diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/parameters.py
b/submarine-sdk/pysubmarine/submarine/ml/pytorch/parameters.py
index 5a74c2e..4619460 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/parameters.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/parameters.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
default_parameters = {
"output": {
"save_model_dir": "./output",
diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/registries.py
b/submarine-sdk/pysubmarine/submarine/ml/pytorch/registries.py
index 017be19..2df1609 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/registries.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/registries.py
@@ -17,6 +17,4 @@ from .input.libsvm_dataset import libsvm_input_fn
LIBSVM = "libsvm"
-input_fn_registry = {
- LIBSVM: libsvm_input_fn
-}
+input_fn_registry = {LIBSVM: libsvm_input_fn}
diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/input/input.py
b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/input/input.py
index ec9ec23..f32285f 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/input/input.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/input/input.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+
import tensorflow as tf
logger = logging.getLogger(__name__)
@@ -21,15 +22,24 @@ logger = logging.getLogger(__name__)
AUTOTUNE = tf.data.experimental.AUTOTUNE
-def libsvm_input_fn(filepath, batch_size=256, num_epochs=3, # pylint:
disable=W0613
- perform_shuffle=False, delimiter=" ", **kwargs):
+def libsvm_input_fn(
+ filepath,
+ batch_size=256,
+ num_epochs=3, # pylint: disable=W0613
+ perform_shuffle=False,
+ delimiter=" ",
+ **kwargs):
+
def _input_fn():
+
def decode_libsvm(line):
columns = tf.string_split([line], delimiter)
labels = tf.string_to_number(columns.values[0],
out_type=tf.float32)
splits = tf.string_split(columns.values[1:], ':')
id_vals = tf.reshape(splits.values, splits.dense_shape)
- feat_ids, feat_vals = tf.split(id_vals, num_or_size_splits=2,
axis=1)
+ feat_ids, feat_vals = tf.split(id_vals,
+ num_or_size_splits=2,
+ axis=1)
feat_ids = tf.string_to_number(feat_ids, out_type=tf.int32)
feat_vals = tf.string_to_number(feat_vals, out_type=tf.float32)
return {"feat_ids": feat_ids, "feat_vals": feat_vals}, labels
@@ -44,4 +54,5 @@ def libsvm_input_fn(filepath, batch_size=256, num_epochs=3,
# pylint: disable=W
dataset = dataset.batch(batch_size)
return dataset
+
return _input_fn
diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/layers/core.py
b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/layers/core.py
index bf72185..f3dd83e 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/layers/core.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/layers/core.py
@@ -17,17 +17,34 @@ import tensorflow as tf
def batch_norm_layer(x, train_phase, scope_bn, batch_norm_decay):
- bn_train = tf.contrib.layers.batch_norm(x, decay=batch_norm_decay,
center=True, scale=True,
- updates_collections=None,
is_training=True,
- reuse=None, scope=scope_bn)
- bn_infer = tf.contrib.layers.batch_norm(x, decay=batch_norm_decay,
center=True, scale=True,
- updates_collections=None,
is_training=False,
- reuse=True, scope=scope_bn)
- return tf.cond(tf.cast(train_phase, tf.bool), lambda: bn_train, lambda:
bn_infer)
-
-
-def dnn_layer(inputs, estimator_mode, batch_norm, deep_layers, dropout,
batch_norm_decay=0.9,
- l2_reg=0, **kwargs):
+ bn_train = tf.contrib.layers.batch_norm(x,
+ decay=batch_norm_decay,
+ center=True,
+ scale=True,
+ updates_collections=None,
+ is_training=True,
+ reuse=None,
+ scope=scope_bn)
+ bn_infer = tf.contrib.layers.batch_norm(x,
+ decay=batch_norm_decay,
+ center=True,
+ scale=True,
+ updates_collections=None,
+ is_training=False,
+ reuse=True,
+ scope=scope_bn)
+ return tf.cond(tf.cast(train_phase, tf.bool), lambda: bn_train,
+ lambda: bn_infer)
+
+
+def dnn_layer(inputs,
+ estimator_mode,
+ batch_norm,
+ deep_layers,
+ dropout,
+ batch_norm_decay=0.9,
+ l2_reg=0,
+ **kwargs):
"""
The Multi Layer Percetron
:param inputs: A tensor of at least rank 2 and static value for the last
dimension; i.e.
@@ -51,18 +68,23 @@ def dnn_layer(inputs, estimator_mode, batch_norm,
deep_layers, dropout, batch_no
for i in range(len(deep_layers)):
deep_inputs = tf.contrib.layers.fully_connected(
- inputs=inputs, num_outputs=deep_layers[i],
+ inputs=inputs,
+ num_outputs=deep_layers[i],
weights_regularizer=tf.contrib.layers.l2_regularizer(l2_reg),
scope='mlp%d' % i)
if batch_norm:
deep_inputs = batch_norm_layer(
- deep_inputs, train_phase=train_phase,
- scope_bn='bn_%d' % i, batch_norm_decay=batch_norm_decay)
+ deep_inputs,
+ train_phase=train_phase,
+ scope_bn='bn_%d' % i,
+ batch_norm_decay=batch_norm_decay)
if estimator_mode == tf.estimator.ModeKeys.TRAIN:
deep_inputs = tf.nn.dropout(deep_inputs, keep_prob=dropout[i])
deep_out = tf.contrib.layers.fully_connected(
- inputs=deep_inputs, num_outputs=1, activation_fn=tf.identity,
+ inputs=deep_inputs,
+ num_outputs=1,
+ activation_fn=tf.identity,
weights_regularizer=tf.contrib.layers.l2_regularizer(l2_reg),
scope='deep_out')
deep_out = tf.reshape(deep_out, shape=[-1])
@@ -85,18 +107,27 @@ def linear_layer(features, feature_size, field_size,
l2_reg=0, **kwargs):
regularizer = tf.contrib.layers.l2_regularizer(l2_reg)
with tf.variable_scope("LinearLayer_Layer"):
- linear_bias = tf.get_variable(name='linear_bias', shape=[1],
+ linear_bias = tf.get_variable(name='linear_bias',
+ shape=[1],
initializer=tf.constant_initializer(0.0))
- linear_weight = tf.get_variable(name='linear_weight',
shape=[feature_size],
-
initializer=tf.glorot_normal_initializer(),
- regularizer=regularizer)
+ linear_weight = tf.get_variable(
+ name='linear_weight',
+ shape=[feature_size],
+ initializer=tf.glorot_normal_initializer(),
+ regularizer=regularizer)
feat_weights = tf.nn.embedding_lookup(linear_weight, feat_ids)
- linear_out = tf.reduce_sum(tf.multiply(feat_weights, feat_vals), 1) +
linear_bias
+ linear_out = tf.reduce_sum(tf.multiply(feat_weights, feat_vals),
+ 1) + linear_bias
return linear_out
-def embedding_layer(features, feature_size, field_size, embedding_size,
l2_reg=0, **kwargs):
+def embedding_layer(features,
+ feature_size,
+ field_size,
+ embedding_size,
+ l2_reg=0,
+ **kwargs):
"""
Turns positive integers (indexes) into dense vectors of fixed size.
eg. [[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]
@@ -114,10 +145,11 @@ def embedding_layer(features, feature_size, field_size,
embedding_size, l2_reg=0
with tf.variable_scope("Embedding_Layer"):
regularizer = tf.contrib.layers.l2_regularizer(l2_reg)
- embedding_dict = tf.get_variable(name='embedding_dict',
- shape=[feature_size, embedding_size],
-
initializer=tf.glorot_normal_initializer(),
- regularizer=regularizer)
+ embedding_dict = tf.get_variable(
+ name='embedding_dict',
+ shape=[feature_size, embedding_size],
+ initializer=tf.glorot_normal_initializer(),
+ regularizer=regularizer)
embeddings = tf.nn.embedding_lookup(embedding_dict, feat_ids)
feat_vals = tf.reshape(feat_vals, shape=[-1, field_size, 1])
embedding_out = tf.multiply(embeddings, feat_vals)
diff --git
a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/base_tf_model.py
b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/base_tf_model.py
index 8172293..118f876 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/base_tf_model.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/base_tf_model.py
@@ -13,14 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from abc import ABC
import logging
-import tensorflow as tf
+from abc import ABC
+
import numpy as np
+import tensorflow as tf
+
from submarine.ml.abstract_model import AbstractModel
-from submarine.ml.tensorflow.registries import input_fn_registry
from submarine.ml.tensorflow.parameters import default_parameters
-from submarine.utils.env import get_from_registry, get_from_dicts,
get_from_json
+from submarine.ml.tensorflow.registries import input_fn_registry
+from submarine.utils.env import (get_from_dicts, get_from_json,
+ get_from_registry)
from submarine.utils.tf_utils import get_tf_config
logger = logging.getLogger(__name__)
@@ -28,6 +31,7 @@ logger = logging.getLogger(__name__)
# pylint: disable=W0221
class BaseTFModel(AbstractModel, ABC):
+
def __init__(self, model_params=None, json_path=None):
super().__init__()
self.model_params = get_from_dicts(model_params, default_parameters)
@@ -37,9 +41,10 @@ class BaseTFModel(AbstractModel, ABC):
self.input_type = self.model_params['input']['type']
self.model_dir = self.model_params['output']['save_model_dir']
self.config = get_tf_config(self.model_params)
- self.model = tf.estimator.Estimator(
- model_fn=self.model_fn, model_dir=self.model_dir,
- params=self.model_params, config=self.config)
+ self.model = tf.estimator.Estimator(model_fn=self.model_fn,
+ model_dir=self.model_dir,
+ params=self.model_params,
+ config=self.config)
def train(self, train_input_fn=None, eval_input_fn=None, **kwargs):
"""
@@ -51,19 +56,18 @@ class BaseTFModel(AbstractModel, ABC):
if train_input_fn is None:
train_input_fn = get_from_registry(
self.input_type, input_fn_registry)(
- filepath=self.model_params['input']['train_data'],
- **self.model_params['training'])
+ filepath=self.model_params['input']['train_data'],
+ **self.model_params['training'])
if eval_input_fn is None:
eval_input_fn = get_from_registry(
self.input_type, input_fn_registry)(
- filepath=self.model_params['input']['valid_data'],
- **self.model_params['training'])
+ filepath=self.model_params['input']['valid_data'],
+ **self.model_params['training'])
- train_spec = tf.estimator.TrainSpec(
- input_fn=train_input_fn)
- eval_spec = tf.estimator.EvalSpec(
- input_fn=eval_input_fn)
- tf.estimator.train_and_evaluate(self.model, train_spec, eval_spec,
**kwargs)
+ train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn)
+ eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
+ tf.estimator.train_and_evaluate(self.model, train_spec, eval_spec,
+ **kwargs)
def evaluate(self, eval_input_fn=None, **kwargs):
"""
@@ -76,8 +80,8 @@ class BaseTFModel(AbstractModel, ABC):
if eval_input_fn is None:
eval_input_fn = get_from_registry(
self.input_type, input_fn_registry)(
- filepath=self.model_params['input']['valid_data'],
- **self.model_params['training'])
+ filepath=self.model_params['input']['valid_data'],
+ **self.model_params['training'])
return self.model.evaluate(input_fn=eval_input_fn, **kwargs)
@@ -91,21 +95,18 @@ class BaseTFModel(AbstractModel, ABC):
if predict_input_fn is None:
predict_input_fn = get_from_registry(
self.input_type, input_fn_registry)(
- filepath=self.model_params['input']['test_data'],
- **self.model_params['training'])
+ filepath=self.model_params['input']['test_data'],
+ **self.model_params['training'])
return self.model.predict(input_fn=predict_input_fn, **kwargs)
def _sanity_checks(self):
assert 'input' in self.model_params, (
- 'Does not define any input parameters'
- )
+ 'Does not define any input parameters')
assert 'type' in self.model_params['input'], (
- 'Does not define any input type'
- )
+ 'Does not define any input type')
assert 'output' in self.model_params, (
- 'Does not define any output parameters'
- )
+ 'Does not define any output parameters')
def model_fn(self, features, labels, mode, params):
seed = params["training"]["seed"]
diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/deepfm.py
b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/deepfm.py
index e88ccd0..844b9a6 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/deepfm.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/deepfm.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""
Tensorflow implementation of DeepFM
@@ -26,15 +25,19 @@ Reference:
"""
import logging
+
import tensorflow as tf
+
+from submarine.ml.tensorflow.layers.core import (dnn_layer, embedding_layer,
+ fm_layer, linear_layer)
from submarine.ml.tensorflow.model.base_tf_model import BaseTFModel
-from submarine.ml.tensorflow.layers.core import fm_layer, linear_layer,
dnn_layer, embedding_layer
from submarine.utils.tf_utils import get_estimator_spec
logger = logging.getLogger(__name__)
class DeepFM(BaseTFModel):
+
def model_fn(self, features, labels, mode, params):
super().model_fn(features, labels, mode, params)
@@ -45,7 +48,8 @@ class DeepFM(BaseTFModel):
field_size = params['training']['field_size']
embedding_size = params['training']['embedding_size']
- deep_inputs = tf.reshape(embedding_outputs, shape=[-1, field_size *
embedding_size])
+ deep_inputs = tf.reshape(embedding_outputs,
+ shape=[-1, field_size * embedding_size])
deep_logit = dnn_layer(deep_inputs, mode, **params['training'])
with tf.variable_scope("DeepFM_out"):
diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/fm.py
b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/fm.py
index e3433dc..74b7e8d 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/fm.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/fm.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""
TensorFlow implementation of FM
@@ -22,8 +21,11 @@ Reference:
"""
import logging
+
import tensorflow as tf
-from submarine.ml.tensorflow.layers.core import linear_layer, fm_layer,
embedding_layer
+
+from submarine.ml.tensorflow.layers.core import (embedding_layer, fm_layer,
+ linear_layer)
from submarine.ml.tensorflow.model.base_tf_model import BaseTFModel
from submarine.utils.tf_utils import get_estimator_spec
@@ -31,6 +33,7 @@ logger = logging.getLogger(__name__)
class FM(BaseTFModel):
+
def model_fn(self, features, labels, mode, params):
super().model_fn(features, labels, mode, params)
diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/nfm.py
b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/nfm.py
index d592a8c..2ae4d7a 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/nfm.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/model/nfm.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""
TensorFlow implementation of NFM
@@ -23,23 +22,26 @@ Reference:
"""
import logging
+
import tensorflow as tf
+
+from submarine.ml.tensorflow.layers.core import (bilinear_layer, dnn_layer,
+ embedding_layer, linear_layer)
from submarine.ml.tensorflow.model.base_tf_model import BaseTFModel
-from submarine.ml.tensorflow.layers.core import dnn_layer, bilinear_layer,\
- linear_layer, embedding_layer
from submarine.utils.tf_utils import get_estimator_spec
logger = logging.getLogger(__name__)
class NFM(BaseTFModel):
+
def model_fn(self, features, labels, mode, params):
super().model_fn(features, labels, mode, params)
linear_logit = linear_layer(features, **params['training'])
embedding_outputs = embedding_layer(features, **params['training'])
deep_inputs = bilinear_layer(embedding_outputs, **params['training'])
- deep_logit = dnn_layer(deep_inputs, mode, **params['training'])
+ deep_logit = dnn_layer(deep_inputs, mode, **params['training'])
with tf.variable_scope("NFM_out"):
logit = linear_logit + deep_logit
diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/optimizer.py
b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/optimizer.py
index 6719e10..89bc28a 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/optimizer.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/optimizer.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+
import tensorflow as tf
logger = logging.getLogger(__name__)
@@ -32,12 +33,15 @@ def get_optimizer(optimizer_key, learning_rate):
if optimizer_key == OptimizerKey.ADAM:
op = tf.train.AdamOptimizer(learning_rate=learning_rate,
- beta1=0.9, beta2=0.999, epsilon=1e-8)
+ beta1=0.9,
+ beta2=0.999,
+ epsilon=1e-8)
elif optimizer_key == OptimizerKey.ADAGRAD:
- op = tf.train.AdagradOptimizer(
- learning_rate=learning_rate, initial_accumulator_value=1e-8)
+ op = tf.train.AdagradOptimizer(learning_rate=learning_rate,
+ initial_accumulator_value=1e-8)
elif optimizer_key == OptimizerKey.MOMENTUM:
- op = tf.train.MomentumOptimizer(learning_rate=learning_rate,
momentum=0.95)
+ op = tf.train.MomentumOptimizer(learning_rate=learning_rate,
+ momentum=0.95)
elif optimizer_key == OptimizerKey.FTRL:
op = tf.train.FtrlOptimizer(learning_rate)
else:
diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/parameters.py
b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/parameters.py
index 18d8604..a35312d 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/parameters.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/parameters.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
default_parameters = {
"output": {
"save_model_dir": "./experiment",
diff --git a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/registries.py
b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/registries.py
index 74f5ca0..0ce59af 100644
--- a/submarine-sdk/pysubmarine/submarine/ml/tensorflow/registries.py
+++ b/submarine-sdk/pysubmarine/submarine/ml/tensorflow/registries.py
@@ -17,6 +17,4 @@ from submarine.ml.tensorflow.input import libsvm_input_fn
LIBSVM = "libsvm"
-input_fn_registry = {
- LIBSVM: libsvm_input_fn
-}
+input_fn_registry = {LIBSVM: libsvm_input_fn}
diff --git a/submarine-sdk/pysubmarine/submarine/store/database/db_types.py
b/submarine-sdk/pysubmarine/submarine/store/database/db_types.py
index aff53ee..f788371 100644
--- a/submarine-sdk/pysubmarine/submarine/store/database/db_types.py
+++ b/submarine-sdk/pysubmarine/submarine/store/database/db_types.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""
Set of SQLAlchemy database schemas supported in Submarine for tracking server
backends.
"""
@@ -22,9 +21,4 @@ MYSQL = 'mysql'
SQLITE = 'sqlite'
MSSQL = 'mssql'
-DATABASE_ENGINES = [
- POSTGRES,
- MYSQL,
- SQLITE,
- MSSQL
-]
+DATABASE_ENGINES = [POSTGRES, MYSQL, SQLITE, MSSQL]
diff --git a/submarine-sdk/pysubmarine/submarine/store/database/models.py
b/submarine-sdk/pysubmarine/submarine/store/database/models.py
index 6f501db..ab76885 100644
--- a/submarine-sdk/pysubmarine/submarine/store/database/models.py
+++ b/submarine-sdk/pysubmarine/submarine/store/database/models.py
@@ -14,12 +14,13 @@
# limitations under the License.
import time
+
import sqlalchemy as sa
-from sqlalchemy import (Column, String, BigInteger,
- PrimaryKeyConstraint, Boolean)
+from sqlalchemy import (BigInteger, Boolean, Column, PrimaryKeyConstraint,
+ String)
from sqlalchemy.ext.declarative import declarative_base
-from submarine.entities import (Metric, Param)
+from submarine.entities import Metric, Param
Base = declarative_base()
@@ -68,26 +69,31 @@ class SqlMetric(Base):
JOB NAME to which this metric belongs to: Part of *Primary Key* for
``metrics`` table.
"""
- __table_args__ = (
- PrimaryKeyConstraint('key', 'timestamp', 'worker_index', 'step',
'job_name',
- 'value', "is_nan", name='metric_pk'),
- )
+ __table_args__ = (PrimaryKeyConstraint('key',
+ 'timestamp',
+ 'worker_index',
+ 'step',
+ 'job_name',
+ 'value',
+ "is_nan",
+ name='metric_pk'),)
def __repr__(self):
- return '<SqlMetric({}, {}, {}, {}, {})>'.format(self.key, self.value,
self.worker_index,
- self.timestamp,
self.step)
+ return '<SqlMetric({}, {}, {}, {}, {})>'.format(self.key, self.value,
+ self.worker_index,
+ self.timestamp,
+ self.step)
def to_submarine_entity(self):
"""
Convert DB model to corresponding Submarine entity.
:return: :py:class:`submarine.entities.Metric`.
"""
- return Metric(
- key=self.key,
- value=self.value if not self.is_nan else float("nan"),
- worker_index=self.worker_index,
- timestamp=self.timestamp,
- step=self.step)
+ return Metric(key=self.key,
+ value=self.value if not self.is_nan else float("nan"),
+ worker_index=self.worker_index,
+ timestamp=self.timestamp,
+ step=self.step)
# +----------+-------+--------------+-----------------------+
@@ -120,19 +126,20 @@ class SqlParam(Base):
JOB NAME to which this parameter belongs to: Part of *Primary Key* for
``params`` table.
"""
- __table_args__ = (
- PrimaryKeyConstraint('key', 'job_name', 'worker_index',
name='param_pk'),
- )
+ __table_args__ = (PrimaryKeyConstraint('key',
+ 'job_name',
+ 'worker_index',
+ name='param_pk'),)
def __repr__(self):
- return '<SqlParam({}, {}, {})>'.format(self.key, self.value,
self.worker_index)
+ return '<SqlParam({}, {}, {})>'.format(self.key, self.value,
+ self.worker_index)
def to_submarine_entity(self):
"""
Convert DB model to corresponding submarine entity.
:return: :py:class:`submarine.entities.Param`.
"""
- return Param(
- key=self.key,
- value=self.value,
- worker_index=self.worker_index)
+ return Param(key=self.key,
+ value=self.value,
+ worker_index=self.worker_index)
diff --git a/submarine-sdk/pysubmarine/submarine/store/sqlalchemy_store.py
b/submarine-sdk/pysubmarine/submarine/store/sqlalchemy_store.py
index bd29c02..38c5a17 100644
--- a/submarine-sdk/pysubmarine/submarine/store/sqlalchemy_store.py
+++ b/submarine-sdk/pysubmarine/submarine/store/sqlalchemy_store.py
@@ -14,15 +14,15 @@
# limitations under the License.
import logging
-from contextlib import contextmanager
import math
+from contextlib import contextmanager
+
import sqlalchemy
from submarine.exceptions import SubmarineException
-from submarine.utils import extract_db_type_from_uri
-from submarine.store.database.models import Base, SqlMetric, SqlParam
from submarine.store.abstract_store import AbstractStore
-
+from submarine.store.database.models import Base, SqlMetric, SqlParam
+from submarine.utils import extract_db_type_from_uri
_logger = logging.getLogger(__name__)
@@ -82,6 +82,7 @@ class SqlAlchemyStore(AbstractStore):
encountered, the session is rolled back. Finally, any session produced
by this factory is
automatically closed when the session's associated context is exited.
"""
+
@contextmanager
def make_managed_session():
"""Provide a transactional scope around a series of operations."""
@@ -136,18 +137,26 @@ class SqlAlchemyStore(AbstractStore):
value = float(metric.value)
with self.ManagedSessionMaker() as session:
try:
- self._get_or_create(model=SqlMetric, job_name=job_name,
key=metric.key,
- value=value,
worker_index=metric.worker_index,
- timestamp=metric.timestamp,
step=metric.step,
- session=session, is_nan=is_nan)
+ self._get_or_create(model=SqlMetric,
+ job_name=job_name,
+ key=metric.key,
+ value=value,
+ worker_index=metric.worker_index,
+ timestamp=metric.timestamp,
+ step=metric.step,
+ session=session,
+ is_nan=is_nan)
except sqlalchemy.exc.IntegrityError:
session.rollback()
def log_param(self, job_name, param):
with self.ManagedSessionMaker() as session:
try:
- self._get_or_create(model=SqlParam, job_name=job_name,
session=session,
- key=param.key, value=param.value,
+ self._get_or_create(model=SqlParam,
+ job_name=job_name,
+ session=session,
+ key=param.key,
+ value=param.value,
worker_index=param.worker_index)
session.commit()
except sqlalchemy.exc.IntegrityError:
diff --git a/submarine-sdk/pysubmarine/submarine/tracking/__init__.py
b/submarine-sdk/pysubmarine/submarine/tracking/__init__.py
index a1cd9aa..097d7fc 100644
--- a/submarine-sdk/pysubmarine/submarine/tracking/__init__.py
+++ b/submarine-sdk/pysubmarine/submarine/tracking/__init__.py
@@ -14,8 +14,8 @@
# limitations under the License.
from submarine.tracking.client import SubmarineClient
-from submarine.tracking.utils import set_tracking_uri, get_tracking_uri,
_TRACKING_URI_ENV_VAR, \
- _JOB_NAME_ENV_VAR
+from submarine.tracking.utils import (_JOB_NAME_ENV_VAR, _TRACKING_URI_ENV_VAR,
+ get_tracking_uri, set_tracking_uri)
__all__ = [
"SubmarineClient",
diff --git a/submarine-sdk/pysubmarine/submarine/tracking/client.py
b/submarine-sdk/pysubmarine/submarine/tracking/client.py
index 0f04743..5bca2dc 100644
--- a/submarine-sdk/pysubmarine/submarine/tracking/client.py
+++ b/submarine-sdk/pysubmarine/submarine/tracking/client.py
@@ -13,13 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
-from submarine.entities import Param, Metric
+
+from submarine.entities import Metric, Param
from submarine.tracking import utils
from submarine.utils.validation import validate_metric, validate_param
class SubmarineClient(object):
-
"""
Client of an submarine Tracking Server that creates and manages
experiments and runs.
"""
@@ -34,7 +34,13 @@ class SubmarineClient(object):
self.tracking_uri = tracking_uri or utils.get_tracking_uri()
self.store = utils.get_sqlalchemy_store(self.tracking_uri)
- def log_metric(self, job_name, key, value, worker_index, timestamp=None,
step=None):
+ def log_metric(self,
+ job_name,
+ key,
+ value,
+ worker_index,
+ timestamp=None,
+ step=None):
"""
Log a metric against the run ID.
:param job_name: The job name to which the metric should be logged.
diff --git a/submarine-sdk/pysubmarine/submarine/tracking/fluent.py
b/submarine-sdk/pysubmarine/submarine/tracking/fluent.py
index f0689e5..2cbb4c1 100644
--- a/submarine-sdk/pysubmarine/submarine/tracking/fluent.py
+++ b/submarine-sdk/pysubmarine/submarine/tracking/fluent.py
@@ -12,20 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""
Internal module implementing the fluent API, allowing management of an active
Submarine run. This module is exposed to users at the top-level
:py:mod:`submarine` module.
"""
from __future__ import print_function
-from submarine.tracking.client import SubmarineClient
-from submarine.tracking.utils import get_job_name
-import time
import logging
import random
import string
+import time
+from submarine.tracking.client import SubmarineClient
+from submarine.tracking.utils import get_job_name
_RUN_ID_ENV_VAR = "SUBMARINE_RUN_ID"
_active_run_stack = []
@@ -62,5 +61,5 @@ def log_metric(key, value, worker_index, step=None):
:param step: Metric step (int). Defaults to zero if unspecified.
"""
job_name = get_job_name()
- SubmarineClient().log_metric(
- job_name, key, value, worker_index, int(time.time() * 1000), step or 0)
+ SubmarineClient().log_metric(job_name, key, value, worker_index,
+ int(time.time() * 1000), step or 0)
diff --git a/submarine-sdk/pysubmarine/submarine/tracking/utils.py
b/submarine-sdk/pysubmarine/submarine/tracking/utils.py
index 952de07..c8ea3cc 100644
--- a/submarine-sdk/pysubmarine/submarine/tracking/utils.py
+++ b/submarine-sdk/pysubmarine/submarine/tracking/utils.py
@@ -14,6 +14,7 @@
# limitations under the License.
from __future__ import print_function
+
from submarine.store import DEFAULT_SUBMARINE_JDBC_URL
from submarine.store.sqlalchemy_store import SqlAlchemyStore
from submarine.utils import env
diff --git a/submarine-sdk/pysubmarine/submarine/utils/__init__.py
b/submarine-sdk/pysubmarine/submarine/utils/__init__.py
index b36ce34..8a94b60 100644
--- a/submarine-sdk/pysubmarine/submarine/utils/__init__.py
+++ b/submarine-sdk/pysubmarine/submarine/utils/__init__.py
@@ -14,6 +14,7 @@
# limitations under the License.
from six.moves import urllib
+
from submarine.exceptions import SubmarineException
@@ -30,7 +31,8 @@ def extract_db_type_from_uri(db_uri):
elif scheme_plus_count == 1:
db_type, _ = scheme.split('+')
else:
- error_msg = "Invalid database URI: '%s'. %s" % (db_uri,
'INVALID_DB_URI_MSG')
+ error_msg = "Invalid database URI: '%s'. %s" % (db_uri,
+ 'INVALID_DB_URI_MSG')
raise SubmarineException(error_msg)
return db_type
diff --git a/submarine-sdk/pysubmarine/submarine/utils/env.py
b/submarine-sdk/pysubmarine/submarine/utils/env.py
index 046c2e6..56d133e 100644
--- a/submarine-sdk/pysubmarine/submarine/utils/env.py
+++ b/submarine-sdk/pysubmarine/submarine/utils/env.py
@@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
-import json
import copy
+import json
+import os
from collections import Mapping
@@ -64,8 +64,8 @@ def get_from_dicts(params, defaultParams):
dct = copy.deepcopy(defaultParams)
for k, _ in params.items():
- if (k in dct and isinstance(dct[k], dict)
- and isinstance(defaultParams[k], Mapping)):
+ if (k in dct and isinstance(dct[k], dict) and
+ isinstance(defaultParams[k], Mapping)):
dct[k] = get_from_dicts(params[k], dct[k])
else:
dct[k] = params[k]
@@ -78,8 +78,5 @@ def get_from_registry(key, registry):
if key in registry:
return registry[key]
else:
- raise ValueError(
- 'Key {} not supported, available options: {}'.format(
- key, registry.keys()
- )
- )
+ raise ValueError('Key {} not supported, available options: {}'.format(
+ key, registry.keys()))
diff --git a/submarine-sdk/pysubmarine/submarine/utils/fileio.py
b/submarine-sdk/pysubmarine/submarine/utils/fileio.py
index 410dbe3..699e1a5 100644
--- a/submarine-sdk/pysubmarine/submarine/utils/fileio.py
+++ b/submarine-sdk/pysubmarine/submarine/utils/fileio.py
@@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from pyarrow import fs
-
import io
-from urllib.parse import urlparse
import os
from enum import Enum
+from urllib.parse import urlparse
+
+from pyarrow import fs
class _Scheme(Enum):
diff --git a/submarine-sdk/pysubmarine/submarine/utils/rest_utils.py
b/submarine-sdk/pysubmarine/submarine/utils/rest_utils.py
index 50c6f89..0437469 100644
--- a/submarine-sdk/pysubmarine/submarine/utils/rest_utils.py
+++ b/submarine-sdk/pysubmarine/submarine/utils/rest_utils.py
@@ -13,16 +13,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from submarine.exceptions import SubmarineException, RestException
import json
-import requests
import logging
+import requests
+
+from submarine.exceptions import RestException, SubmarineException
+
_logger = logging.getLogger(__name__)
-def http_request(base_url, endpoint, method, json_body,
- timeout=60, headers=None, **kwargs):
+def http_request(base_url,
+ endpoint,
+ method,
+ json_body,
+ timeout=60,
+ headers=None,
+ **kwargs):
"""
Perform requests.
:param base_url: http request base url containing hostname and port. e.g.
https://submarine:8088
@@ -34,15 +41,20 @@ def http_request(base_url, endpoint, method, json_body,
:return:
"""
method = method.upper()
- assert method in ['GET', 'HEAD', 'DELETE', 'POST', 'PUT',
- 'PATCH', 'OPTIONS']
+ assert method in [
+ 'GET', 'HEAD', 'DELETE', 'POST', 'PUT', 'PATCH', 'OPTIONS'
+ ]
headers = headers or {}
if 'Content-Type' not in headers:
headers['Content-Type'] = 'application/json'
url = base_url + endpoint
- response = requests.request(url=url, method=method, json=json_body,
headers=headers,
- timeout=timeout, **kwargs)
+ response = requests.request(url=url,
+ method=method,
+ json=json_body,
+ headers=headers,
+ timeout=timeout,
+ **kwargs)
verify_rest_response(response, endpoint)
response = json.loads(response.text)
@@ -66,5 +78,6 @@ def verify_rest_response(response, endpoint):
else:
base_msg = "API request to endpoint %s failed with error code " \
"%s != 200" % (endpoint, response.status_code)
- raise SubmarineException("%s. Response body: '%s'" % (base_msg,
response.text))
+ raise SubmarineException("%s. Response body: '%s'" %
+ (base_msg, response.text))
return response
diff --git a/submarine-sdk/pysubmarine/submarine/utils/tf_utils.py
b/submarine-sdk/pysubmarine/submarine/utils/tf_utils.py
index fcfca4a..96da0ce 100644
--- a/submarine-sdk/pysubmarine/submarine/utils/tf_utils.py
+++ b/submarine-sdk/pysubmarine/submarine/utils/tf_utils.py
@@ -13,9 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import json
+import os
+
import tensorflow as tf
+
from submarine.ml.tensorflow.optimizer import get_optimizer
@@ -28,13 +30,17 @@ def _get_session_config_from_env_var(params):
and 'index' in tf_config['task']:
# Master should only communicate with itself and ps.
if tf_config['task']['type'] == 'master':
- return tf.ConfigProto(device_filters=['/job:ps', '/job:master'],
-
intra_op_parallelism_threads=params["resource"]['num_thread'],
-
inter_op_parallelism_threads=params["resource"]['num_thread'])
+ return tf.ConfigProto(
+ device_filters=['/job:ps', '/job:master'],
+ intra_op_parallelism_threads=params["resource"]['num_thread'],
+ inter_op_parallelism_threads=params["resource"]['num_thread'])
# Worker should only communicate with itself and ps.
elif tf_config['task']['type'] == 'worker':
return tf.ConfigProto( # gpu_options=gpu_options,
- device_filters=['/job:ps', '/job:worker/task:%d' %
tf_config['task']['index']],
+ device_filters=[
+ '/job:ps',
+ '/job:worker/task:%d' % tf_config['task']['index']
+ ],
intra_op_parallelism_threads=params["resource"]['num_thread'],
inter_op_parallelism_threads=params["resource"]['num_thread'])
return None
@@ -52,8 +58,10 @@ def get_tf_config(params):
if params["training"]['mode'] == 'local': # local mode
tf_config = tf.estimator.RunConfig().replace(
session_config=tf.ConfigProto(
- device_count={'GPU': params["resource"]['num_gpu'],
- 'CPU': params["resource"]['num_cpu']},
+ device_count={
+ 'GPU': params["resource"]['num_gpu'],
+ 'CPU': params["resource"]['num_cpu']
+ },
intra_op_parallelism_threads=params["resource"]['num_thread'],
inter_op_parallelism_threads=params["resource"]['num_thread']),
log_step_count_steps=params["training"]['log_steps'],
@@ -62,8 +70,10 @@ def get_tf_config(params):
elif params["training"]['mode'] == 'distributed':
tf_config = tf.estimator.RunConfig(
experimental_distribute=tf.contrib.distribute.DistributeConfig(
-
train_distribute=tf.contrib.distribute.ParameterServerStrategy(),
-
eval_distribute=tf.contrib.distribute.ParameterServerStrategy()),
+ train_distribute=tf.contrib.distribute.ParameterServerStrategy(
+ ),
+ eval_distribute=tf.contrib.distribute.ParameterServerStrategy(
+ )),
session_config=_get_session_config_from_env_var(params),
save_summary_steps=params["training"]['log_steps'],
log_step_count_steps=params["training"]['log_steps'])
@@ -90,17 +100,18 @@ def get_estimator_spec(logit, labels, mode, params):
predictions = {"probabilities": output}
export_outputs = {
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
- tf.estimator.export.PredictOutput(predictions)}
+ tf.estimator.export.PredictOutput(predictions)
+ }
# Provide an estimator spec for `ModeKeys.PREDICT`
if mode == tf.estimator.ModeKeys.PREDICT:
- return tf.estimator.EstimatorSpec(
- mode=mode,
- predictions=predictions,
- export_outputs=export_outputs)
+ return tf.estimator.EstimatorSpec(mode=mode,
+ predictions=predictions,
+ export_outputs=export_outputs)
with tf.name_scope("Loss"):
loss = tf.reduce_mean(
- tf.nn.sigmoid_cross_entropy_with_logits(logits=logit,
labels=labels))
+ tf.nn.sigmoid_cross_entropy_with_logits(logits=logit,
+ labels=labels))
# Provide an estimator spec for `ModeKeys.EVAL`
eval_metric_ops = {}
@@ -110,11 +121,10 @@ def get_estimator_spec(logit, labels, mode, params):
raise TypeError("Invalid metric :", metric)
if mode == tf.estimator.ModeKeys.EVAL:
- return tf.estimator.EstimatorSpec(
- mode=mode,
- predictions=predictions,
- loss=loss,
- eval_metric_ops=eval_metric_ops)
+ return tf.estimator.EstimatorSpec(mode=mode,
+ predictions=predictions,
+ loss=loss,
+ eval_metric_ops=eval_metric_ops)
with tf.name_scope("Train"):
op = get_optimizer(optimizer, learning_rate)
@@ -122,8 +132,7 @@ def get_estimator_spec(logit, labels, mode, params):
# Provide an estimator spec for `ModeKeys.TRAIN` modes
if mode == tf.estimator.ModeKeys.TRAIN:
- return tf.estimator.EstimatorSpec(
- mode=mode,
- predictions=predictions,
- loss=loss,
- train_op=train_op)
+ return tf.estimator.EstimatorSpec(mode=mode,
+ predictions=predictions,
+ loss=loss,
+ train_op=train_op)
diff --git a/submarine-sdk/pysubmarine/submarine/utils/validation.py
b/submarine-sdk/pysubmarine/submarine/utils/validation.py
index 25733b1..304adac 100644
--- a/submarine-sdk/pysubmarine/submarine/utils/validation.py
+++ b/submarine-sdk/pysubmarine/submarine/utils/validation.py
@@ -12,13 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""
Utilities for validating user inputs such as metric names and parameter names.
"""
import numbers
-import re
import posixpath
+import re
from submarine.exceptions import SubmarineException
from submarine.store.database.db_types import DATABASE_ENGINES
@@ -28,13 +27,12 @@ _VALID_PARAM_AND_METRIC_NAMES = re.compile(r"^[/\w.\- ]*$")
MAX_ENTITY_KEY_LENGTH = 250
MAX_PARAM_VAL_LENGTH = 250
-
_BAD_CHARACTERS_MESSAGE = (
"Names may only contain alphanumerics, underscores (_), dashes (-),
periods (.),"
- " spaces ( ), and slashes (/)."
-)
+ " spaces ( ), and slashes (/).")
-_UNSUPPORTED_DB_TYPE_MSG = "Supported database engines are {%s}" % ',
'.join(DATABASE_ENGINES)
+_UNSUPPORTED_DB_TYPE_MSG = "Supported database engines are {%s}" % ', '.join(
+ DATABASE_ENGINES)
def bad_path_message(name):
@@ -46,27 +44,31 @@ def bad_path_message(name):
def path_not_unique(name):
norm = posixpath.normpath(name)
- return norm != name or norm == '.' or norm.startswith('..') or
norm.startswith('/')
+ return norm != name or norm == '.' or norm.startswith(
+ '..') or norm.startswith('/')
def _validate_param_name(name):
"""Check that `name` is a valid parameter name and raise an exception if
it isn't."""
if not _VALID_PARAM_AND_METRIC_NAMES.match(name):
raise SubmarineException(
- "Invalid parameter name: '%s'. %s" % (name,
_BAD_CHARACTERS_MESSAGE),)
+ "Invalid parameter name: '%s'. %s" %
+ (name, _BAD_CHARACTERS_MESSAGE),)
if path_not_unique(name):
- raise SubmarineException(
- "Invalid parameter name: '%s'. %s" % (name,
bad_path_message(name)))
+ raise SubmarineException("Invalid parameter name: '%s'. %s" %
+ (name, bad_path_message(name)))
def _validate_metric_name(name):
"""Check that `name` is a valid metric name and raise an exception if it
isn't."""
if not _VALID_PARAM_AND_METRIC_NAMES.match(name):
- raise SubmarineException("Invalid metric name: '%s'. %s" % (name,
_BAD_CHARACTERS_MESSAGE),)
+ raise SubmarineException(
+ "Invalid metric name: '%s'. %s" % (name, _BAD_CHARACTERS_MESSAGE),)
if path_not_unique(name):
- raise SubmarineException("Invalid metric name: '%s'. %s" % (name,
bad_path_message(name)))
+ raise SubmarineException("Invalid metric name: '%s'. %s" %
+ (name, bad_path_message(name)))
def _validate_length_limit(entity_name, limit, value):
@@ -111,5 +113,6 @@ def validate_param(key, value):
def _validate_db_type_string(db_type):
"""validates db_type parsed from DB URI is supported"""
if db_type not in DATABASE_ENGINES:
- error_msg = "Invalid database engine: '%s'. '%s'" % (db_type,
_UNSUPPORTED_DB_TYPE_MSG)
+ error_msg = "Invalid database engine: '%s'. '%s'" % (
+ db_type, _UNSUPPORTED_DB_TYPE_MSG)
raise SubmarineException(error_msg)
diff --git a/submarine-sdk/pysubmarine/tests/ml/pytorch/model/conftest.py
b/submarine-sdk/pysubmarine/tests/ml/pytorch/model/conftest.py
index b6f6a88..997a709 100644
--- a/submarine-sdk/pysubmarine/tests/ml/pytorch/model/conftest.py
+++ b/submarine-sdk/pysubmarine/tests/ml/pytorch/model/conftest.py
@@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
import pytest
-import os
# noqa
LIBSVM_DATA = """
@@ -65,10 +65,11 @@ def get_model_param(tmpdir):
"model": {
"name": "ctr.deepfm",
"kwargs": {
- "field_dims":
- [15, 52, 30, 19, 111, 51, 26, 19, 53, 5, 13, 8, 23, 21,
- 77, 25, 39, 11, 8, 61, 15, 3, 34, 75, 30, 79, 11,
- 85, 37, 10, 94, 19, 5, 32, 6, 12, 42, 18, 23],
+ "field_dims": [
+ 15, 52, 30, 19, 111, 51, 26, 19, 53, 5, 13, 8, 23, 21, 77,
+ 25, 39, 11, 8, 61, 15, 3, 34, 75, 30, 79, 11, 85, 37, 10,
+ 94, 19, 5, 32, 6, 12, 42, 18, 23
+ ],
"out_features": 1,
"embedding_dim": 16,
"hidden_units": [400, 400],
diff --git a/submarine-sdk/pysubmarine/tests/ml/pytorch/test_loss_pytorch.py
b/submarine-sdk/pysubmarine/tests/ml/pytorch/test_loss_pytorch.py
index b0af1ce..bacdd1b 100644
--- a/submarine-sdk/pysubmarine/tests/ml/pytorch/test_loss_pytorch.py
+++ b/submarine-sdk/pysubmarine/tests/ml/pytorch/test_loss_pytorch.py
@@ -14,6 +14,7 @@
# limitations under the License.
import pytest
+
from submarine.ml.pytorch.loss import get_loss_fn
diff --git a/submarine-sdk/pysubmarine/tests/ml/pytorch/test_metric_pytorch.py
b/submarine-sdk/pysubmarine/tests/ml/pytorch/test_metric_pytorch.py
index 4ece597..3a42dba 100644
--- a/submarine-sdk/pysubmarine/tests/ml/pytorch/test_metric_pytorch.py
+++ b/submarine-sdk/pysubmarine/tests/ml/pytorch/test_metric_pytorch.py
@@ -14,6 +14,7 @@
# limitations under the License.
import pytest
+
from submarine.ml.pytorch.metric import get_metric_fn
diff --git
a/submarine-sdk/pysubmarine/tests/ml/pytorch/test_optimizer_pytorch.py
b/submarine-sdk/pysubmarine/tests/ml/pytorch/test_optimizer_pytorch.py
index c8112f3..edb399f 100644
--- a/submarine-sdk/pysubmarine/tests/ml/pytorch/test_optimizer_pytorch.py
+++ b/submarine-sdk/pysubmarine/tests/ml/pytorch/test_optimizer_pytorch.py
@@ -14,6 +14,7 @@
# limitations under the License.
import pytest
+
from submarine.ml.pytorch.optimizer import get_optimizer
diff --git a/submarine-sdk/pysubmarine/tests/ml/tensorflow/model/conftest.py
b/submarine-sdk/pysubmarine/tests/ml/tensorflow/model/conftest.py
index 4db552e..17cf295 100644
--- a/submarine-sdk/pysubmarine/tests/ml/tensorflow/model/conftest.py
+++ b/submarine-sdk/pysubmarine/tests/ml/tensorflow/model/conftest.py
@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import pytest
import os
+import pytest
+
LIBSVM_DATA = """1 1:0 2:0.051495 3:0.5 4:0.1 5:0.113437 6:0.874 7:0.01 8:0.08
9:0.028 10:0
1 1:1.35 2:0.031561 3:0.45 4:0.56 5:0.000031 6:0.056 7:0.27 8:0.58 9:0.056
10:0.166667
1 1:0.05 2:0.004983 3:0.19 4:0.14 5:0.000016 6:0.006 7:0.01 8:0.14 9:0.014
10:0.166667
diff --git
a/submarine-sdk/pysubmarine/tests/ml/tensorflow/model/test_base_tf_model.py
b/submarine-sdk/pysubmarine/tests/ml/tensorflow/model/test_base_tf_model.py
index 04542d4..7fb475a 100644
--- a/submarine-sdk/pysubmarine/tests/ml/tensorflow/model/test_base_tf_model.py
+++ b/submarine-sdk/pysubmarine/tests/ml/tensorflow/model/test_base_tf_model.py
@@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from submarine.ml.tensorflow.model.base_tf_model import BaseTFModel
import pytest
+from submarine.ml.tensorflow.model.base_tf_model import BaseTFModel
+
def test_create_base_tf_model():
params = {"learning rate": 0.05}
- with pytest.raises(AssertionError, match="Does not define any input
parameters"):
+ with pytest.raises(AssertionError,
+ match="Does not define any input parameters"):
BaseTFModel(params)
params.update({'input': {'train_data': '/tmp/train.csv'}})
diff --git a/submarine-sdk/pysubmarine/tests/ml/tensorflow/model/test_deepfm.py
b/submarine-sdk/pysubmarine/tests/ml/tensorflow/model/test_deepfm.py
index 43cd141..55f8e37 100644
--- a/submarine-sdk/pysubmarine/tests/ml/tensorflow/model/test_deepfm.py
+++ b/submarine-sdk/pysubmarine/tests/ml/tensorflow/model/test_deepfm.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from submarine.ml.tensorflow.model import DeepFM
diff --git a/submarine-sdk/pysubmarine/tests/ml/tensorflow/model/test_fm.py
b/submarine-sdk/pysubmarine/tests/ml/tensorflow/model/test_fm.py
index 6663f9f..bedda94 100644
--- a/submarine-sdk/pysubmarine/tests/ml/tensorflow/model/test_fm.py
+++ b/submarine-sdk/pysubmarine/tests/ml/tensorflow/model/test_fm.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from submarine.ml.tensorflow.model import FM
diff --git a/submarine-sdk/pysubmarine/tests/ml/tensorflow/model/test_nfm.py
b/submarine-sdk/pysubmarine/tests/ml/tensorflow/model/test_nfm.py
index 89819b0..dab76b5 100644
--- a/submarine-sdk/pysubmarine/tests/ml/tensorflow/model/test_nfm.py
+++ b/submarine-sdk/pysubmarine/tests/ml/tensorflow/model/test_nfm.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from submarine.ml.tensorflow.model import NFM
diff --git a/submarine-sdk/pysubmarine/tests/ml/tensorflow/test_optimizer.py
b/submarine-sdk/pysubmarine/tests/ml/tensorflow/test_optimizer.py
index bc860f6..96daf09 100644
--- a/submarine-sdk/pysubmarine/tests/ml/tensorflow/test_optimizer.py
+++ b/submarine-sdk/pysubmarine/tests/ml/tensorflow/test_optimizer.py
@@ -14,6 +14,7 @@
# limitations under the License.
import pytest
+
from submarine.ml.tensorflow.optimizer import get_optimizer
@@ -26,4 +27,5 @@ def test_get_optimizer():
for invalid_optimizer_key in invalid_optimizer_keys:
with pytest.raises(ValueError, match="Invalid optimizer_key :"):
- get_optimizer(optimizer_key=invalid_optimizer_key,
learning_rate=0.3)
+ get_optimizer(optimizer_key=invalid_optimizer_key,
+ learning_rate=0.3)
diff --git a/submarine-sdk/pysubmarine/tests/store/test_sqlalchemy_store.py
b/submarine-sdk/pysubmarine/tests/store/test_sqlalchemy_store.py
index f2d3263..db8b1ec 100644
--- a/submarine-sdk/pysubmarine/tests/store/test_sqlalchemy_store.py
+++ b/submarine-sdk/pysubmarine/tests/store/test_sqlalchemy_store.py
@@ -13,23 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import submarine
+import time
+import unittest
from os import environ
+
+import submarine
+from submarine.entities import Metric, Param
+from submarine.store.database import models
from submarine.store.database.models import SqlMetric, SqlParam
from submarine.tracking import utils
-from submarine.store.database import models
-from submarine.entities import Metric, Param
-
-import time
-import unittest
JOB_NAME = "application_123456789"
class TestSqlAlchemyStore(unittest.TestCase):
+
def setUp(self):
submarine.set_tracking_uri(
-
"mysql+pymysql://submarine_test:password_test@localhost:3306/submarine_test")
+
"mysql+pymysql://submarine_test:password_test@localhost:3306/submarine_test"
+ )
self.tracking_uri = utils.get_tracking_uri()
self.store = utils.get_sqlalchemy_store(self.tracking_uri)
diff --git a/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py
b/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py
index 20efbca..76f6514 100644
--- a/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py
+++ b/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py
@@ -13,22 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import submarine
+import unittest
from os import environ
+
+import submarine
+from submarine.store.database import models
from submarine.store.database.models import SqlMetric, SqlParam
from submarine.tracking import utils
-from submarine.store.database import models
-
-import unittest
JOB_NAME = "application_123456789"
class TestTracking(unittest.TestCase):
+
def setUp(self):
environ["SUBMARINE_JOB_NAME"] = JOB_NAME
submarine.set_tracking_uri(
-
"mysql+pymysql://submarine_test:password_test@localhost:3306/submarine_test")
+
"mysql+pymysql://submarine_test:password_test@localhost:3306/submarine_test"
+ )
self.tracking_uri = utils.get_tracking_uri()
self.store = utils.get_sqlalchemy_store(self.tracking_uri)
diff --git a/submarine-sdk/pysubmarine/tests/tracking/test_utils.py
b/submarine-sdk/pysubmarine/tests/tracking/test_utils.py
index be057bf..33422b9 100644
--- a/submarine-sdk/pysubmarine/tests/tracking/test_utils.py
+++ b/submarine-sdk/pysubmarine/tests/tracking/test_utils.py
@@ -13,12 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import mock
import os
+
+import mock
+
from submarine.store import DEFAULT_SUBMARINE_JDBC_URL
from submarine.store.sqlalchemy_store import SqlAlchemyStore
-from submarine.tracking.utils import is_tracking_uri_set,
_TRACKING_URI_ENV_VAR, \
- get_tracking_uri, _JOB_NAME_ENV_VAR, get_job_name, get_sqlalchemy_store
+from submarine.tracking.utils import (_JOB_NAME_ENV_VAR, _TRACKING_URI_ENV_VAR,
+ get_job_name, get_sqlalchemy_store,
+ get_tracking_uri, is_tracking_uri_set)
def test_is_tracking_uri_set():
@@ -39,8 +42,7 @@ def test_get_tracking_uri():
def test_get_job_name():
env = {
- _JOB_NAME_ENV_VAR:
- "application_12346789",
+ _JOB_NAME_ENV_VAR: "application_12346789",
}
with mock.patch.dict(os.environ, env):
assert get_job_name() == "application_12346789"
@@ -49,9 +51,7 @@ def test_get_job_name():
def test_get_sqlalchemy_store():
patch_create_engine = mock.patch("sqlalchemy.create_engine")
uri = DEFAULT_SUBMARINE_JDBC_URL
- env = {
- _TRACKING_URI_ENV_VAR: uri
- }
+ env = {_TRACKING_URI_ENV_VAR: uri}
with mock.patch.dict(os.environ, env), patch_create_engine as
mock_create_engine, \
mock.patch("submarine.store.sqlalchemy_store.SqlAlchemyStore._initialize_tables"):
store = get_sqlalchemy_store(uri)
diff --git a/submarine-sdk/pysubmarine/tests/utils/test_env.py
b/submarine-sdk/pysubmarine/tests/utils/test_env.py
index b6919b0..eefd176 100644
--- a/submarine-sdk/pysubmarine/tests/utils/test_env.py
+++ b/submarine-sdk/pysubmarine/tests/utils/test_env.py
@@ -13,12 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from submarine.utils.env import get_env, unset_variable
-from submarine.utils.env import get_from_dicts, get_from_json,
get_from_registry
-from os import environ
import json
+from os import environ
+
import pytest
+from submarine.utils.env import (get_env, get_from_dicts, get_from_json,
+ get_from_registry, unset_variable)
+
@pytest.fixture(scope="function")
def output_json_filepath():
diff --git a/submarine-sdk/pysubmarine/tests/utils/test_rest_utils.py
b/submarine-sdk/pysubmarine/tests/utils/test_rest_utils.py
index 62769d1..3a866e9 100644
--- a/submarine-sdk/pysubmarine/tests/utils/test_rest_utils.py
+++ b/submarine-sdk/pysubmarine/tests/utils/test_rest_utils.py
@@ -13,23 +13,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import patch, Mock
-import pytest
import json
-from submarine.utils.rest_utils import http_request, verify_rest_response
+
+import pytest
+from mock import Mock, patch
+
from submarine.exceptions import RestException, SubmarineException
+from submarine.utils.rest_utils import http_request, verify_rest_response
def test_http_request():
- dummy_json = json.dumps({'result': {'jobId': 'job_1234567', 'name':
'submarine',
- 'identifier': 'test'}})
+ dummy_json = json.dumps({
+ 'result': {
+ 'jobId': 'job_1234567',
+ 'name': 'submarine',
+ 'identifier': 'test'
+ }
+ })
with patch('requests.request') as mock_requests:
mock_requests.return_value.text = dummy_json
mock_requests.return_value.status_code = 200
- result = http_request('http://submarine:8080', json_body='dummy',
- endpoint='/api/v1/jobs', method='POST')
+ result = http_request('http://submarine:8080',
+ json_body='dummy',
+ endpoint='/api/v1/jobs',
+ method='POST')
assert result['jobId'] == 'job_1234567'
assert result['name'] == 'submarine'
@@ -54,6 +63,7 @@ def test_verify_rest_response():
# Test response status code not equal 200(OK) and response can not parse
as JSON
mock_json_body = 'test, 123'
mock_response.text = mock_json_body
- with pytest.raises(SubmarineException, match='API request to endpoint
/api/v1/jobs failed '
- 'with error code 400 != 200'):
+ with pytest.raises(SubmarineException,
+ match='API request to endpoint /api/v1/jobs failed '
+ 'with error code 400 != 200'):
verify_rest_response(mock_response, '/api/v1/jobs')
diff --git a/submarine-sdk/pysubmarine/tests/utils/test_tf_utils.py
b/submarine-sdk/pysubmarine/tests/utils/test_tf_utils.py
index 736510e..52ed4ad 100644
--- a/submarine-sdk/pysubmarine/tests/utils/test_tf_utils.py
+++ b/submarine-sdk/pysubmarine/tests/utils/test_tf_utils.py
@@ -14,6 +14,7 @@
# limitations under the License.
import pytest
+
from submarine.utils.tf_utils import get_tf_config
@@ -23,11 +24,29 @@ def test_get_tf_config():
get_tf_config(params)
# conf for local training
- params.update({'training': {'mode': 'local', 'log_steps': 10},
- 'resource': {'num_cpu': 4, 'num_thread': 4, 'num_gpu': 1}})
+ params.update({
+ 'training': {
+ 'mode': 'local',
+ 'log_steps': 10
+ },
+ 'resource': {
+ 'num_cpu': 4,
+ 'num_thread': 4,
+ 'num_gpu': 1
+ }
+ })
get_tf_config(params)
# conf for distributed training
- params.update({'training': {'mode': 'distributed', 'log_steps': 10},
- 'resource': {'num_cpu': 4, 'num_thread': 4, 'num_gpu': 2}})
+ params.update({
+ 'training': {
+ 'mode': 'distributed',
+ 'log_steps': 10
+ },
+ 'resource': {
+ 'num_cpu': 4,
+ 'num_thread': 4,
+ 'num_gpu': 2
+ }
+ })
get_tf_config(params)
diff --git a/submarine-sdk/pysubmarine/tests/utils/test_validation.py
b/submarine-sdk/pysubmarine/tests/utils/test_validation.py
index 1382084..0030913 100644
--- a/submarine-sdk/pysubmarine/tests/utils/test_validation.py
+++ b/submarine-sdk/pysubmarine/tests/utils/test_validation.py
@@ -16,14 +16,35 @@
import pytest
from submarine.exceptions import SubmarineException
-from submarine.utils.validation import _validate_metric_name,
_validate_param_name,\
- _validate_length_limit, _validate_db_type_string
+from submarine.utils.validation import (_validate_db_type_string,
+ _validate_length_limit,
+ _validate_metric_name,
+ _validate_param_name)
GOOD_METRIC_OR_PARAM_NAMES = [
- "a", "Ab-5_", "a/b/c", "a.b.c", ".a", "b.", "a..a/._./o_O/.e.", "a b/c d",
+ "a",
+ "Ab-5_",
+ "a/b/c",
+ "a.b.c",
+ ".a",
+ "b.",
+ "a..a/._./o_O/.e.",
+ "a b/c d",
]
BAD_METRIC_OR_PARAM_NAMES = [
- "", ".", "/", "..", "//", "a//b", "a/./b", "/a", "a/", ":", "\\", "./",
"/./",
+ "",
+ ".",
+ "/",
+ "..",
+ "//",
+ "a//b",
+ "a/./b",
+ "/a",
+ "a/",
+ ":",
+ "\\",
+ "./",
+ "/./",
]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]