Skip to content

Commit

Permalink
feat(experiment): Abstract manager client
Browse files Browse the repository at this point in the history
Signed-off-by: Ce Gao <[email protected]>
  • Loading branch information
gaocegege committed May 20, 2019
1 parent 02e6afb commit 118961c
Show file tree
Hide file tree
Showing 10 changed files with 197 additions and 79 deletions.
11 changes: 7 additions & 4 deletions pkg/controller/v1alpha2/experiment/experiment_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import (
experimentsv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/experiment/v1alpha2"
trialsv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/trial/v1alpha2"
"github.com/kubeflow/katib/pkg/controller/v1alpha2/consts"
"github.com/kubeflow/katib/pkg/controller/v1alpha2/experiment/managerclient"
"github.com/kubeflow/katib/pkg/controller/v1alpha2/experiment/manifest"
"github.com/kubeflow/katib/pkg/controller/v1alpha2/experiment/suggestion"
suggestionfake "github.com/kubeflow/katib/pkg/controller/v1alpha2/experiment/suggestion/fake"
Expand All @@ -64,8 +65,9 @@ func Add(mgr manager.Manager) error {
// newReconciler returns a new reconcile.Reconciler
func newReconciler(mgr manager.Manager) reconcile.Reconciler {
r := &ReconcileExperiment{
Client: mgr.GetClient(),
scheme: mgr.GetScheme(),
Client: mgr.GetClient(),
scheme: mgr.GetScheme(),
ManagerClient: managerclient.New(),
}
imp := viper.GetString(consts.ConfigExperimentSuggestionName)
r.Suggestion = newSuggestion(imp)
Expand Down Expand Up @@ -197,6 +199,7 @@ type ReconcileExperiment struct {

suggestion.Suggestion
manifest.Generator
managerclient.ManagerClient
// updateStatusHandler is defined for test purpose.
updateStatusHandler updateStatusFunc
}
Expand Down Expand Up @@ -243,7 +246,7 @@ func (r *ReconcileExperiment) Reconcile(request reconcile.Request) (reconcile.Re
msg := "Experiment is created"
instance.MarkExperimentStatusCreated(util.ExperimentCreatedReason, msg)

err = util.CreateExperimentInDB(instance)
err = r.CreateExperimentInDB(instance)
if err != nil {
logger.Error(err, "Create experiment in DB error")
return reconcile.Result{}, err
Expand All @@ -259,7 +262,7 @@ func (r *ReconcileExperiment) Reconcile(request reconcile.Request) (reconcile.Re

if !equality.Semantic.DeepEqual(original.Status, instance.Status) {
//assuming that only status change
err = util.UpdateExperimentStatusInDB(instance)
err = r.UpdateExperimentStatusInDB(instance)
if err != nil {
logger.Error(err, "Update experiment status in DB error")
return reconcile.Result{}, err
Expand Down
7 changes: 3 additions & 4 deletions pkg/controller/v1alpha2/experiment/experiment_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
experimentsv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/experiment/v1alpha2"
trialsv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/trial/v1alpha2"
apiv1alpha2 "github.com/kubeflow/katib/pkg/api/v1alpha2"
"github.com/kubeflow/katib/pkg/controller/v1alpha2/experiment/util"
)

func (r *ReconcileExperiment) createTrialInstance(expInstance *experimentsv1alpha2.Experiment, trialInstance *apiv1alpha2.Trial) error {
Expand All @@ -39,8 +38,8 @@ func (r *ReconcileExperiment) createTrialInstance(expInstance *experimentsv1alph
if trialInstance.Spec != nil && trialInstance.Spec.ParameterAssignments != nil {
for _, p := range trialInstance.Spec.ParameterAssignments.Assignments {
hps = append(hps, p)
pa := common.ParameterAssignment {
Name: p.Name,
pa := common.ParameterAssignment{
Name: p.Name,
Value: p.Value,
}
trial.Spec.ParameterAssignments = append(trial.Spec.ParameterAssignments, pa)
Expand Down Expand Up @@ -85,7 +84,7 @@ func (r *ReconcileExperiment) createTrialInstance(expInstance *experimentsv1alph
func (r *ReconcileExperiment) updateFinalizers(instance *experimentsv1alpha2.Experiment, finalizers []string) (reconcile.Result, error) {
logger := log.WithValues("Experiment", types.NamespacedName{Name: instance.Name, Namespace: instance.Namespace})
if instance.GetDeletionTimestamp() != nil {
if err := util.DeleteExperimentInDB(instance); err != nil {
if err := r.DeleteExperimentInDB(instance); err != nil {
logger.Error(err, "Fail to delete data in DB")
return reconcile.Result{}, err
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,71 +1,75 @@
/*
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package util
package managerclient

import (
"database/sql"

commonv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/common/v1alpha2"
commonapiv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/common/v1alpha2"
experimentsv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/experiment/v1alpha2"
api_pb "github.com/kubeflow/katib/pkg/api/v1alpha2"
common "github.com/kubeflow/katib/pkg/common/v1alpha2"
commonv1alpha2 "github.com/kubeflow/katib/pkg/common/v1alpha2"
)

func CreateExperimentInDB(instance *experimentsv1alpha2.Experiment) error {
experiment := GetExperimentConf(instance)
// ManagerClient is the interface for katib manager client
// in experiment controller.
type ManagerClient interface {
CreateExperimentInDB(instance *experimentsv1alpha2.Experiment) error
DeleteExperimentInDB(instance *experimentsv1alpha2.Experiment) error
UpdateExperimentStatusInDB(instance *experimentsv1alpha2.Experiment) error
GetExperimentFromDB(instance *experimentsv1alpha2.Experiment) (
*api_pb.GetExperimentReply, error)
}

// DefaultClient implements the Client interface.
type DefaultClient struct {
}

// New creates a new ManagerClient.
func New() ManagerClient {
return &DefaultClient{}
}

func (d *DefaultClient) CreateExperimentInDB(instance *experimentsv1alpha2.Experiment) error {
experiment := getExperimentConf(instance)
request := &api_pb.RegisterExperimentRequest{
Experiment: experiment,
}
if _, err := common.RegisterExperiment(request); err != nil {
if _, err := commonv1alpha2.RegisterExperiment(request); err != nil {
return err
}
return nil
}

func DeleteExperimentInDB(instance *experimentsv1alpha2.Experiment) error {
func (d *DefaultClient) DeleteExperimentInDB(instance *experimentsv1alpha2.Experiment) error {
request := &api_pb.DeleteExperimentRequest{
ExperimentName: instance.Name,
}
if _, err := common.DeleteExperiment(request); err != nil {
if _, err := commonv1alpha2.DeleteExperiment(request); err != nil {
return err
}
return nil
}

func UpdateExperimentStatusInDB(instance *experimentsv1alpha2.Experiment) error {
func (d *DefaultClient) UpdateExperimentStatusInDB(instance *experimentsv1alpha2.Experiment) error {
newStatus := &api_pb.ExperimentStatus{
StartTime: common.ConvertTime2RFC3339(instance.Status.StartTime),
CompletionTime: common.ConvertTime2RFC3339(instance.Status.CompletionTime),
StartTime: commonv1alpha2.ConvertTime2RFC3339(instance.Status.StartTime),
CompletionTime: commonv1alpha2.ConvertTime2RFC3339(instance.Status.CompletionTime),
Condition: getCondition(instance),
}
request := &api_pb.UpdateExperimentStatusRequest{
NewStatus: newStatus,
ExperimentName: instance.Name,
}
if _, err := common.UpdateExperimentStatus(request); err != nil {
if _, err := commonv1alpha2.UpdateExperimentStatus(request); err != nil {
return err
}
return nil
}

func GetExperimentFromDB(instance *experimentsv1alpha2.Experiment) (*api_pb.GetExperimentReply, error) {
func (d *DefaultClient) GetExperimentFromDB(instance *experimentsv1alpha2.Experiment) (*api_pb.GetExperimentReply, error) {
return nil, sql.ErrNoRows
}

func GetExperimentConf(instance *experimentsv1alpha2.Experiment) *api_pb.Experiment {
func getExperimentConf(instance *experimentsv1alpha2.Experiment) *api_pb.Experiment {
experiment := &api_pb.Experiment{
Spec: &api_pb.ExperimentSpec{
Objective: &api_pb.ObjectiveSpec{
Expand All @@ -76,8 +80,8 @@ func GetExperimentConf(instance *experimentsv1alpha2.Experiment) *api_pb.Experim
},
},
Status: &api_pb.ExperimentStatus{
StartTime: common.ConvertTime2RFC3339(instance.Status.StartTime),
CompletionTime: common.ConvertTime2RFC3339(instance.Status.CompletionTime),
StartTime: commonv1alpha2.ConvertTime2RFC3339(instance.Status.StartTime),
CompletionTime: commonv1alpha2.ConvertTime2RFC3339(instance.Status.CompletionTime),
Condition: getCondition(instance),
},
}
Expand All @@ -86,9 +90,9 @@ func GetExperimentConf(instance *experimentsv1alpha2.Experiment) *api_pb.Experim

//Populate Objective
switch instance.Spec.Objective.Type {
case commonv1alpha2.ObjectiveTypeMaximize:
case commonapiv1alpha2.ObjectiveTypeMaximize:
experiment.Spec.Objective.Type = api_pb.ObjectiveType_MAXIMIZE
case commonv1alpha2.ObjectiveTypeMinimize:
case commonapiv1alpha2.ObjectiveTypeMinimize:
experiment.Spec.Objective.Type = api_pb.ObjectiveType_MINIMIZE
default:
experiment.Spec.Objective.Type = api_pb.ObjectiveType_UNKNOWN
Expand Down
4 changes: 3 additions & 1 deletion pkg/controller/v1alpha2/experiment/validation_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/webhook/admission/types"

experimentsv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/experiment/v1alpha2"
"github.com/kubeflow/katib/pkg/controller/v1alpha2/experiment/managerclient"
"github.com/kubeflow/katib/pkg/controller/v1alpha2/experiment/manifest"
"github.com/kubeflow/katib/pkg/controller/v1alpha2/experiment/validator"
)
Expand All @@ -43,8 +44,9 @@ func newExperimentValidator() (*experimentValidator, error) {
if err != nil {
return nil, err
}
mc := managerclient.New()
return &experimentValidator{
Validator: validator.New(p),
Validator: validator.New(p, mc),
}, nil
}

Expand Down
10 changes: 6 additions & 4 deletions pkg/controller/v1alpha2/experiment/validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import (
commonapiv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/common/v1alpha2"
experimentsv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/experiment/v1alpha2"
commonv1alpha2 "github.com/kubeflow/katib/pkg/common/v1alpha2"
"github.com/kubeflow/katib/pkg/controller/v1alpha2/experiment/managerclient"
"github.com/kubeflow/katib/pkg/controller/v1alpha2/experiment/manifest"
"github.com/kubeflow/katib/pkg/controller/v1alpha2/experiment/util"
)

var log = logf.Log.WithName("experiment-controller")
Expand All @@ -25,11 +25,13 @@ type Validator interface {

type DefaultValidator struct {
manifest.Generator
managerclient.ManagerClient
}

func New(generator manifest.Generator) Validator {
func New(generator manifest.Generator, managerClient managerclient.ManagerClient) Validator {
return &DefaultValidator{
Generator: generator,
Generator: generator,
ManagerClient: managerClient,
}
}

Expand Down Expand Up @@ -138,7 +140,7 @@ func (g *DefaultValidator) validateSupportedJob(job *unstructured.Unstructured)
}

func (g *DefaultValidator) validateForCreate(inst *experimentsv1alpha2.Experiment) error {
if _, err := util.GetExperimentFromDB(inst); err != nil {
if _, err := g.GetExperimentFromDB(inst); err != nil {
if err != sql.ErrNoRows {
return fmt.Errorf("Fail to check record for the experiment in DB: %v", err)
}
Expand Down
28 changes: 21 additions & 7 deletions pkg/controller/v1alpha2/experiment/validator/validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package validator

import (
"bytes"
"database/sql"
"testing"

"github.com/golang/mock/gomock"
Expand All @@ -10,7 +11,8 @@ import (

commonv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/common/v1alpha2"
experimentsv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/experiment/v1alpha2"
mockmanifest "github.com/kubeflow/katib/pkg/mock/v1alpha2/experiment/manifest"
managerclientmock "github.com/kubeflow/katib/pkg/mock/v1alpha2/experiment/managerclient"
manifestmock "github.com/kubeflow/katib/pkg/mock/v1alpha2/experiment/manifest"
)

func init() {
Expand All @@ -33,11 +35,15 @@ spec:

mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
mockCtrl2 := gomock.NewController(t)
defer mockCtrl2.Finish()

p := mockmanifest.NewMockGenerator(mockCtrl)
g := New(p)
p := manifestmock.NewMockGenerator(mockCtrl)
mc := managerclientmock.NewMockManagerClient(mockCtrl2)
g := New(p, mc)

p.EXPECT().GetRunSpec(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(trialTFJobTemplate, nil)
mc.EXPECT().GetExperimentFromDB(gomock.Any()).Return(nil, sql.ErrNoRows).AnyTimes()

instance := newFakeInstance()
if err := g.(*DefaultValidator).validateTrialTemplate(instance); err == nil {
Expand All @@ -54,11 +60,15 @@ metadata:

mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
mockCtrl2 := gomock.NewController(t)
defer mockCtrl2.Finish()

p := mockmanifest.NewMockGenerator(mockCtrl)
g := New(p)
p := manifestmock.NewMockGenerator(mockCtrl)
mc := managerclientmock.NewMockManagerClient(mockCtrl2)
g := New(p, mc)

p.EXPECT().GetRunSpec(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(trialJobTemplate, nil)
mc.EXPECT().GetExperimentFromDB(gomock.Any()).Return(nil, sql.ErrNoRows).AnyTimes()

instance := newFakeInstance()
if err := g.(*DefaultValidator).validateTrialTemplate(instance); err != nil {
Expand All @@ -69,9 +79,12 @@ metadata:
func TestValidateExperiment(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
mockCtrl2 := gomock.NewController(t)
defer mockCtrl2.Finish()

p := mockmanifest.NewMockGenerator(mockCtrl)
g := New(p)
p := manifestmock.NewMockGenerator(mockCtrl)
mc := managerclientmock.NewMockManagerClient(mockCtrl2)
g := New(p, mc)

trialJobTemplate := `apiVersion: "batch/v1"
kind: "Job"
Expand Down Expand Up @@ -100,6 +113,7 @@ spec:
println(b.String())
return &b
}(), nil).AnyTimes()
mc.EXPECT().GetExperimentFromDB(gomock.Any()).Return(nil, sql.ErrNoRows).AnyTimes()

tcs := []struct {
Instance *experimentsv1alpha2.Experiment
Expand Down
4 changes: 2 additions & 2 deletions pkg/controller/v1alpha2/trial/trial_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestCreateTFJobTrial(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

mc := managerclientmock.NewMockClient(mockCtrl)
mc := managerclientmock.NewMockManagerClient(mockCtrl)
mc.EXPECT().CreateTrialInDB(gomock.Any()).Return(nil).AnyTimes()
mc.EXPECT().UpdateTrialStatusInDB(gomock.Any()).Return(nil).AnyTimes()

Expand Down Expand Up @@ -92,7 +92,7 @@ func TestReconcileTFJobTrial(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

mc := managerclientmock.NewMockClient(mockCtrl)
mc := managerclientmock.NewMockManagerClient(mockCtrl)
mc.EXPECT().CreateTrialInDB(gomock.Any()).Return(nil).AnyTimes()
mc.EXPECT().UpdateTrialStatusInDB(gomock.Any()).Return(nil).AnyTimes()

Expand Down
Loading

0 comments on commit 118961c

Please sign in to comment.